diff --git a/docs/persistence.rst b/docs/persistence.rst index 7866ea9d..7fced29b 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -85,7 +85,7 @@ In contrast to ``pickle``, skops cannot persist arbitrary Python code. This means if you have custom functions (say, a custom function to be used with :class:`sklearn.preprocessing.FunctionTransformer`), it will not work. However, most ``numpy`` and ``scipy`` functions should work. Therefore, you can actually -save built-in functions like``numpy.sqrt``. +save built-in functions like ``numpy.sqrt``. Roadmap ------- diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 9bda258c..6063dd21 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -1,7 +1,20 @@ -from skops.io.exceptions import UntrustedTypesFoundException +from __future__ import annotations +import io +from contextlib import contextmanager +from typing import Any, Generator, Sequence -def check_type(module_name, type_name, trusted): +from ..utils.fixes import Literal +from ._trusted_types import PRIMITIVE_TYPE_NAMES +from ._utils import LoadContext, get_module +from .exceptions import UntrustedTypesFoundException + +NODE_TYPE_MAPPING = {} # type: ignore + + +def check_type( + module_name: str, type_name: str, trusted: Literal[True] | Sequence[str] +) -> bool: """Check if a type is safe to load. A type is safe to load only if it's present in the trusted list. @@ -14,7 +27,7 @@ def check_type(module_name, type_name, trusted): type_name : str The class name of the type. - trusted : bool, or list of str + trusted : True, or list of str If ``True``, the tree is considered safe. Otherwise trusted has to be a list of trusted types. @@ -28,7 +41,7 @@ def check_type(module_name, type_name, trusted): return module_name + "." + type_name in trusted -def audit_tree(tree, trusted): +def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: """Audit a tree of nodes. A tree is safe if it only contains trusted types. 
Audit is skipped if @@ -39,7 +52,7 @@ def audit_tree(tree, trusted): tree : skops.io._dispatch.Node The tree to audit. - trusted : bool, or list of str + trusted : True, or list of str If ``True``, the tree is considered safe. Otherwise trusted has to be a list of trusted types names. @@ -59,3 +72,290 @@ def audit_tree(tree, trusted): unsafe -= set(trusted) if unsafe: raise UntrustedTypesFoundException(unsafe) + + +class UNINITIALIZED: + """Sentinel value to indicate that a value has not been initialized yet.""" + + +# Note: types for Generator mean: YieldType, SendType, ReturnType +@contextmanager +def temp_setattr(obj: Any, **kwargs: Any) -> Generator[None, None, None]: + """Context manager to temporarily set attributes on an object.""" + existing_attrs = {k for k in kwargs.keys() if hasattr(obj, k)} + previous_values = {k: getattr(obj, k, None) for k in kwargs} + for k, v in kwargs.items(): + setattr(obj, k, v) + try: + yield + finally: + for k, v in previous_values.items(): + if k in existing_attrs: + setattr(obj, k, v) + else: + delattr(obj, k) + + +class Node: + """A node in the tree of objects. + + This class is a parent class for all nodes in the tree of objects. Each + type of object (e.g. dict, list, etc.) has its own subclass of Node. + + Each child class has to implement two methods: ``__init__`` and + ``_construct``. + + ``__init__`` takes care of traversing the state tree and to create the + corresponding ``Node`` objects. It has access to the ``load_context`` which + in turn has access to the source zip file. The child class's ``__init__`` + must load attributes into the ``children`` attribute, which is a + dictionary of ``{child_name: unloaded_value/Node/list/etc}``. The + ``get_unsafe_set`` should be able to parse and validate the values set + under the ``children`` attribute. Note that primitives are persisted as a + ``JsonNode``. + + ``_construct`` takes care of constructing the object. 
It is only called + once and the result is cached in ``construct`` which is implemented in this + class. All required data to construct an instance should be loaded during + ``__init__``. + + The separation of ``__init__`` and ``_construct`` is necessary because + audit methods are called after ``__init__`` and before ``construct``. + Therefore ``__init__`` should avoid creating any instances or importing + any modules, to avoid running potentially untrusted code. + + Parameters + ---------- + state : dict + A dict representing the state of the dumped object. + + load_context : LoadContext + The context of the loading process. + + trusted : bool or list of str, default=False + If ``True``, the object will be loaded without any security checks. If + ``False``, the object will be loaded only if there are only trusted + objects in the dumped file. If a list of strings, the object will be + loaded only if all of its required types are listed in ``trusted`` + or are trusted by default. + + memoize : bool, default=True + If ``True``, the object will be memoized in the load context, if it has + the ``__id__`` set. This is used to avoid loading the same object + multiple times. + """ + + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + memoize: bool = True, + ) -> None: + self.class_name, self.module_name = state["__class__"], state["__module__"] + self._is_safe = None + self._constructed = UNINITIALIZED + saved_id = state.get("__id__") + if saved_id and memoize: + # hold reference to obj in case same instance encountered again in + # save state + load_context.memoize(self, saved_id) + + # subclasses should always: + # 1. call super().__init__() + # 2. set self.trusted = self._get_trusted(trusted, ...) where ... is a + # list of appropriate trusted types + # 3. 
set self.children, where children are states of child nodes; do not + # construct the children objects yet + self.trusted = self._get_trusted(trusted, []) + self.children: dict[str, Any] = {} + + def construct(self): + """Construct the object. + + We only construct the object once, and then cache the result. + """ + if self._constructed is not UNINITIALIZED: + return self._constructed + self._constructed = self._construct() + return self._constructed + + def _construct(self): + raise NotImplementedError( + f"{self.__class__.__name__} should implement a '_construct' method" + ) + + @staticmethod + def _get_trusted( + trusted: bool | Sequence[str], default: list[str] + ) -> Literal[True] | list[str]: + """Return a trusted list, or True. + + If ``trusted`` is ``False``, we return the ``default``, otherwise the + ``trusted`` value is used. + + This is a convenience method called by child classes. + """ + if trusted is True: + # if trusted is True, we trust the node + return True + + if trusted is False: + # if trusted is False, we only trust the defaults + return default + + # otherwise we trust the given list, call list in case it's a tuple + return list(trusted) + + def is_self_safe(self) -> bool: + """True only if the node's type is considered safe. + + This property only checks the type of the node, not its children. + """ + return check_type(self.module_name, self.class_name, self.trusted) + + def is_safe(self) -> bool: + """True only if the node and all its children are safe.""" + # if trusted is set to True, we don't do any safety checks. + if self.trusted is True: + return True + + return len(self.get_unsafe_set()) == 0 + + def get_unsafe_set(self) -> set[str]: + """Get the set of unsafe types. + + This method returns all types which are not trusted, including this + node and all its children. + + Returns + ------- + unsafe_set : set + A set of unsafe types. 
+ """ + if hasattr(self, "_computing_unsafe_set"): + # this means we're already computing this node's unsafe set, so we + # return an empty set and let the computation of the parent node + # continue. This is to avoid infinite recursion. + return set() + + with temp_setattr(self, _computing_unsafe_set=True): + res = set() + if not self.is_self_safe(): + res.add(self.module_name + "." + self.class_name) + + for child in self.children.values(): + if child is None: + continue + + # Get the safety set based on the type of the child. In most cases + # other than ListNode and DictNode, children are all of type Node. + if isinstance(child, list): + # iterate through the list + for value in child: + res.update(value.get_unsafe_set()) + elif isinstance(child, dict): + # iterate through the values of the dict only + # TODO: should we check the types of the keys? + for value in child.values(): + res.update(value.get_unsafe_set()) + elif isinstance(child, Node): + # delegate to the child Node + res.update(child.get_unsafe_set()) + elif type(child) is type: + # the if condition below is not merged with the previous + # one because if the above condition is True, the following + # conditions about BytesIO, etc should be ignored. + if not check_type(get_module(child), child.__name__, self.trusted): + # if the child is a type, we check its safety + res.add(get_module(child) + "." + child.__name__) + elif isinstance(child, io.BytesIO): + # We trust BytesIO objects, which are read by other + # libraries such as numpy, scipy. + continue + elif check_type( + get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES + ): + # if the child is a primitive type, we don't need to check its + # safety. + continue + else: + raise ValueError( + f"Cannot determine the safety of type {type(child)}. Please" + " open an issue at https://github.com/skops-dev/skops/issues" + " for us to fix the issue." 
+ ) + + return res + + +class CachedNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool = False, + ): + # we pass memoize as False because we don't want to memoize the cached + # node. + super().__init__(state, load_context, trusted, memoize=False) + self.trusted = True + # TODO: deal with case that __id__ is unknown or prevent it from + # happening + self.cached = load_context.get_object(state.get("__id__")) # type: ignore + self.children = {} # type: ignore + + def _construct(self): + # TODO: FIXME This causes a recursion error when loading a cached + # object if we call the cached object's ``construct``. Some refactoring + # is needed to fix this. + return self.cached.construct() + + +NODE_TYPE_MAPPING["CachedNode"] = CachedNode + + +def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: + """Get the tree of nodes. + + This function returns the root node of the tree of nodes. The tree is + constructed recursively by traversing the state tree. No instances are + created during this process. One would need to call ``construct`` on the + root node to create the instances. + + This function also handles memoization of the nodes. If a node has already + been created, it is returned instead of creating a new one. + + Parameters + ---------- + state : dict + The state of the dumped object. + + load_context : LoadContext + The context of the loading process. + + Returns + ------- + loaded_tree : Node + The tree containing all its (non-instantiated) child nodes. + """ + saved_id = state.get("__id__") + if saved_id in load_context.memo: + # This means the node is already loaded, so we return it. Note that the + # node is not constructed at this point. It will be constructed when + # the parent node's ``construct`` method is called, and for this node + # it'll be called more than once. But that's not an issue since the + # node's ``construct`` method caches the instance. 
+ return load_context.get_object(saved_id) + + try: + node_cls = NODE_TYPE_MAPPING[state["__loader__"]] + except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name}." + ) + + loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore + + return loaded_tree diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py deleted file mode 100644 index 179bd4b2..00000000 --- a/skops/io/_dispatch.py +++ /dev/null @@ -1,263 +0,0 @@ -from __future__ import annotations - -import io -from contextlib import contextmanager - -from ._audit import check_type -from ._trusted_types import PRIMITIVE_TYPE_NAMES -from ._utils import LoadContext, get_module - -NODE_TYPE_MAPPING = {} # type: ignore - - -class UNINITIALIZED: - """Sentinel value to indicate that a value has not been initialized yet.""" - - -@contextmanager -def temp_setattr(obj, **kwargs): - """Context manager to temporarily set attributes on an object.""" - existing_attrs = {k for k in kwargs.keys() if hasattr(obj, k)} - previous_values = {k: getattr(obj, k, None) for k in kwargs} - for k, v in kwargs.items(): - setattr(obj, k, v) - try: - yield - finally: - for k, v in previous_values.items(): - if k in existing_attrs: - setattr(obj, k, v) - else: - delattr(obj, k) - - -class Node: - """A node in the tree of objects. - - This class is a parent class for all nodes in the tree of objects. Each - type of object (e.g. dict, list, etc.) has its own subclass of Node. - - Each child class has to implement two methods: ``__init__`` and - ``_construct``. - - ``__init__`` takes care of traversing the state tree and to create the - corresponding ``Node`` objects. It has access to the ``load_context`` which - in turn has access to the source zip file. The child class's ``__init__`` - must load attributes into the ``children`` attribute, which is a - dictionary of ``{child_name: unloaded_value/Node/list/etc}``. 
The - ``get_unsafe_set`` should be able to parse and validate the values set - under the ``children`` attribute. Note that primitives are persisted as a - ``JsonNode``. - - ``_construct`` takes care of constructing the object. It is only called - once and the result is cached in ``construct`` which is implemented in this - class. All required data to construct an instance should be loaded during - ``__init__``. - - The separation of ``__init__`` and ``_construct`` is necessary because - audit methods are called after ``__init__`` and before ``construct``. - Therefore ``__init__`` should avoid creating any instances or importing - any modules, to avoid running potentially untrusted code. - - Parameters - ---------- - state : dict - A dict representing the state of the dumped object. - - load_context : LoadContext - The context of the loading process. - - trusted : bool or list of str, default=False - If ``True``, the object will be loaded without any security checks. If - ``False``, the object will be loaded only if there are only trusted - objects in the dumped file. If a list of strings, the object will be - loaded only if all of its required types are listed in ``trusted`` - or are trusted by default. - - memoize : bool, default=True - If ``True``, the object will be memoized in the load context, if it has - the ``__id__`` set. This is used to avoid loading the same object - multiple times. - """ - - def __init__(self, state, load_context: LoadContext, trusted=False, memoize=True): - self.class_name, self.module_name = state["__class__"], state["__module__"] - self.trusted = trusted - self._is_safe = None - self._constructed = UNINITIALIZED - saved_id = state.get("__id__") - if saved_id and memoize: - # hold reference to obj in case same instance encountered again in - # save state - load_context.memoize(self, saved_id) - - def construct(self): - """Construct the object. - - We only construct the object once, and then cache the result. 
- """ - if self._constructed is not UNINITIALIZED: - return self._constructed - self._constructed = self._construct() - return self._constructed - - @staticmethod - def _get_trusted(trusted, default): - """Return a trusted list, or True. - - If ``trusted`` is ``False``, we return the ``default``, otherwise the - ``trusted`` value is used. - - This is a convenience method called by child classes. - """ - if trusted is True: - # if trusted is True, we trust the node - return True - - if trusted is False: - # if trusted is False, we only trust the defaults - return default - - # otherwise we trust the given list - return trusted - - def is_self_safe(self): - """True only if the node's type is considered safe. - - This property only checks the type of the node, not its children. - """ - return check_type(self.module_name, self.class_name, self.trusted) - - def is_safe(self): - """True only if the node and all its children are safe.""" - # if trusted is set to True, we don't do any safety checks. - if self.trusted is True: - return True - - return len(self.get_unsafe_set()) == 0 - - def get_unsafe_set(self): - """Get the set of unsafe types. - - This method returns all types which are not trusted, including this - node and all its children. - - Returns - ------- - unsafe_set : set - A set of unsafe types. - """ - if hasattr(self, "_computing_unsafe_set"): - # this means we're already computing this node's unsafe set, so we - # return an empty set and let the computation of the parent node - # continue. This is to avoid infinite recursion. - return set() - - with temp_setattr(self, _computing_unsafe_set=True): - res = set() - if not self.is_self_safe(): - res.add(self.module_name + "." + self.class_name) - - for child in self.children.values(): - if child is None: - continue - - # Get the safety set based on the type of the child. In most cases - # other than ListNode and DictNode, children are all of type Node. 
- if isinstance(child, list): - # iterate through the list - for value in child: - res.update(value.get_unsafe_set()) - elif isinstance(child, dict): - # iterate through the values of the dict only - # TODO: should we check the types of the keys? - for value in child.values(): - res.update(value.get_unsafe_set()) - elif isinstance(child, Node): - # delegate to the child Node - res.update(child.get_unsafe_set()) - elif type(child) is type: - # the if condition bellow is not merged with the previous - # one because if the above condition is True, the following - # conditions about BytesIO, etc should be ignored. - if not check_type(get_module(child), child.__name__, self.trusted): - # if the child is a type, we check its safety - res.add(get_module(child) + "." + child.__name__) - elif isinstance(child, io.BytesIO): - # We trust BytesIO objects, which are read by other - # libraries such as numpy, scipy. - continue - elif check_type( - get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES - ): - # if the child is a primitive type, we don't need to check its - # safety. - continue - else: - raise ValueError( - f"Cannot determine the safety of type {type(child)}. Please" - " open an issue at https://github.com/skops-dev/skops/issues" - " for us to fix the issue." - ) - - return res - - -class CachedNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): - # we pass memoize as False because we don't want to memoize the cached - # node. - super().__init__(state, load_context, trusted, memoize=False) - self.trusted = True - self.cached = load_context.get_object(state.get("__id__")) - self.children = {} # type: ignore - - def _construct(self): - # TODO: FIXME This causes a recursion error when loading a cached - # object if we call the cached object's `construct``. Some refactoring - # is needed to fix this. 
- return self.cached.construct() - - -NODE_TYPE_MAPPING["CachedNode"] = CachedNode - - -def get_tree(state, load_context: LoadContext): - """Get the tree of nodes. - - This function returns the root node of the tree of nodes. The tree is - constructed recursively by traversing the state tree. No instances are - created during this process. One would need to call ``construct`` on the - root node to create the instances. - - This function also handles memoization of the nodes. If a node has already - been created, it is returned instead of creating a new one. - - Parameters - ---------- - state : dict - The state of the dumped object. - - load_context : LoadContext - The context of the loading process. - """ - saved_id = state.get("__id__") - if saved_id in load_context.memo: - # This means the node is already loaded, so we return it. Note that the - # node is not constructed at this point. It will be constructed when - # the parent node's ``construct`` method is called, and for this node - # it'll be called more than once. But that's not an issue since the - # node's ``construct`` method caches the instance. - return load_context.get_object(saved_id) - - try: - node_cls = NODE_TYPE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." 
- ) - - loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore - - return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index b6449f0d..940f8273 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -3,11 +3,11 @@ import json from functools import partial from types import FunctionType, MethodType -from typing import Any +from typing import Any, Sequence import numpy as np -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import ( LoadContext, @@ -42,7 +42,12 @@ def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class DictNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.dict"]) self.children = { @@ -74,7 +79,12 @@ def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ListNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.list"]) self.children = { @@ -98,7 +108,12 @@ def set_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class SetNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.set"]) self.children = { @@ -122,7 +137,12 @@ def tuple_get_state(obj: Any, save_context: SaveContext) 
-> dict[str, Any]: class TupleNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.tuple"]) self.children = { @@ -139,7 +159,7 @@ def _construct(self): return cls(*content) return content - def isnamedtuple(self, t): + def isnamedtuple(self, t) -> bool: # This is needed since namedtuples need to have the args when # initialized. b = t.__bases__ @@ -165,7 +185,12 @@ def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class FunctionNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) @@ -177,18 +202,15 @@ def _construct(self): self.children["content"]["function"], ) - def _get_function_name(self): + def _get_function_name(self) -> str: return ( self.children["content"]["module_path"] + "." + self.children["content"]["function"] ) - def is_safe(self): - return self._get_function_name() in self.trusted - - def get_unsafe_set(self): - if self.is_safe(): + def get_unsafe_set(self) -> set[str]: + if self.trusted is True: return set() return {self._get_function_name()} @@ -211,7 +233,12 @@ def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class PartialNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: should we trust anything? 
self.trusted = self._get_trusted(trusted, []) @@ -228,7 +255,8 @@ def _construct(self): kwds = self.children["kwds"].construct() namespace = self.children["namespace"].construct() instance = partial(func, *args, **kwds) # always use partial, not a subclass - instance.__setstate__((func, args, kwds, namespace)) + # partial always has __setstate__ + instance.__setstate__((func, args, kwds, namespace)) # type: ignore return instance @@ -244,13 +272,18 @@ def type_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class TypeNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, PRIMITIVE_TYPE_NAMES) # We use a bare Node type here since a Node only checks the type in the # dict using __class__ and __module__ keys. 
- self.children = {} # type: ignore + self.children = {} def _construct(self): return _import_obj(self.module_name, self.class_name) @@ -271,7 +304,12 @@ def slice_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class SliceNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.slice"]) self.children = { @@ -329,11 +367,17 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ObjectNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) - if "content" in state: - attrs = get_tree(state.get("content"), load_context) + content = state.get("content") + if content is not None: + attrs = get_tree(content, load_context) else: attrs = None @@ -345,9 +389,10 @@ def _construct(self): cls = gettype(self.module_name, self.class_name) # Instead of simply constructing the instance, we use __new__, which - # bypasses the __init__, and then we set the attributes. This solves - # the issue of required init arguments. - instance = cls.__new__(cls) + # bypasses the __init__, and then we set the attributes. This solves the + # issue of required init arguments. Note that the instance created here + # might not be valid until all its attributes have been set below. 
+ instance = cls.__new__(cls) # type: ignore if not self.children["attrs"]: # nothing more to do @@ -362,7 +407,7 @@ def _construct(self): return instance -def method_get_state(obj: Any, save_context: SaveContext): +def method_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is used to persist bound methods, which are # dependent on a specific instance of an object. # It stores the state of the object the method is bound to, @@ -380,7 +425,12 @@ def method_get_state(obj: Any, save_context: SaveContext): class MethodNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = { "obj": get_tree(state["content"]["obj"], load_context), @@ -400,19 +450,25 @@ def unsupported_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any] class JsonNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.content = state["content"] + self.children = {} - def is_safe(self): + def is_safe(self) -> bool: # JsonNode is always considered safe. # TODO: should we consider a JsonNode always safe? 
return True - def is_self_safe(self): + def is_self_safe(self) -> bool: return True - def get_unsafe_set(self): + def get_unsafe_set(self) -> set[str]: return set() def _construct(self): diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index cc31e6c8..4676e5f0 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -1,11 +1,11 @@ from __future__ import annotations import io -from typing import Any +from typing import Any, Sequence import numpy as np -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -50,7 +50,12 @@ def ndarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class NdArrayNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.type = state["type"] self.trusted = self._get_trusted(trusted, ["numpy.ndarray"]) @@ -110,7 +115,12 @@ def maskedarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any] class MaskedArrayNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["numpy.ma.MaskedArray"]) self.children = { @@ -136,7 +146,12 @@ def random_state_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any class RandomStateNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"content": 
get_tree(state["content"], load_context)} self.trusted = self._get_trusted(trusted, ["numpy.random.RandomState"]) @@ -159,7 +174,12 @@ def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, class RandomGeneratorNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"bit_generator_state": state["content"]["bit_generator"]} self.trusted = self._get_trusted(trusted, ["numpy.random.Generator"]) @@ -205,7 +225,12 @@ def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class DTypeNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"content": get_tree(state["content"], load_context)} # TODO: what should we trust? 
diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 4d039936..69952161 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -3,12 +3,13 @@ import importlib import io import json +from pathlib import Path +from typing import Any, Sequence from zipfile import ZipFile import skops -from ._audit import audit_tree -from ._dispatch import NODE_TYPE_MAPPING, get_tree +from ._audit import NODE_TYPE_MAPPING, audit_tree, get_tree from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register @@ -23,7 +24,7 @@ NODE_TYPE_MAPPING.update(module.NODE_TYPE_MAPPING) -def _save(obj): +def _save(obj: Any) -> io.BytesIO: buffer = io.BytesIO() with ZipFile(buffer, "w") as zip_file: @@ -38,7 +39,7 @@ def _save(obj): return buffer -def dump(obj, file): +def dump(obj: Any, file: str) -> None: """Save an object using the skops persistence format. Skops aims at providing a secure persistence feature that does not rely on @@ -68,7 +69,7 @@ def dump(obj, file): f.write(buffer.getbuffer()) -def dumps(obj): +def dumps(obj: Any) -> bytes: """Save an object using the skops persistence format as a bytes object. .. warning:: @@ -88,7 +89,7 @@ def dumps(obj): return buffer.getbuffer().tobytes() -def load(file, trusted=False): +def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: """Load an object saved with the skops persistence format. Skops aims at providing a secure persistence feature that does not rely on @@ -104,7 +105,7 @@ def load(file, trusted=False): Parameters ---------- - file: str + file: str or pathlib.Path The file name of the object to be loaded. trusted: bool, or list of str, default=False @@ -130,7 +131,7 @@ def load(file, trusted=False): return instance -def loads(data, trusted=False): +def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: """Load an object saved with the skops persistence format from a bytes object. 
@@ -171,7 +172,9 @@ def loads(data, trusted=False): return instance -def get_untrusted_types(*, data=None, file=None): +def get_untrusted_types( + *, data: bytes | None = None, file: str | Path | None = None +) -> list[str]: """Get a list of untrusted types in a skops dump. Parameters @@ -193,11 +196,15 @@ def get_untrusted_types(*, data=None, file=None): """ if data and file: raise ValueError("Only one of data or file should be passed.") + if not data and not file: + raise ValueError("Exactly one of data or file should be passed.") + content: io.BytesIO | str | Path if data: content = io.BytesIO(data) else: - content = file + # mypy doesn't understand that file cannot be None here, thus ignore + content = file # type: ignore with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 223be95d..0e240b87 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -1,11 +1,11 @@ from __future__ import annotations import io -from typing import Any +from typing import Any, Sequence from scipy.sparse import load_npz, save_npz, spmatrix -from ._dispatch import Node +from ._audit import Node from ._utils import LoadContext, SaveContext, get_module @@ -33,7 +33,12 @@ def sparse_matrix_get_state(obj: Any, save_context: SaveContext) -> dict[str, An class SparseMatrixNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) type = state["type"] self.trusted = self._get_trusted(trusted, ["scipy.sparse.spmatrix"]) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 4a57eea0..9da5392a 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable, Sequence, Type from 
sklearn.cluster import Birch @@ -22,7 +22,7 @@ ) from sklearn.tree._tree import Tree -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._general import unsupported_get_state from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -88,7 +88,13 @@ def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ReduceNode(Node): - def __init__(self, state, load_context: LoadContext, constructor, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + constructor: Type[Any] | Callable[..., Any], + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) reduce = state["__reduce__"] self.children = { @@ -99,7 +105,8 @@ def __init__(self, state, load_context: LoadContext, constructor, trusted=False) def _construct(self): args = self.children["args"].construct() - instance = self.children["constructor"](*args) + constructor = self.children["constructor"] + instance = constructor(*args) attrs = self.children["attrs"].construct() if not attrs: # nothing more to do @@ -107,7 +114,7 @@ def _construct(self): if isinstance(args, tuple) and not hasattr(instance, "__setstate__"): raise UnsupportedTypeException( - f"Objects of type {self.constructor} are not supported yet" + f"Objects of type {constructor} are not supported yet" ) if hasattr(instance, "__setstate__"): @@ -125,7 +132,12 @@ def tree_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class TreeNode(ReduceNode): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, constructor=Tree, trusted=trusted) self.trusted = self._get_trusted(trusted, [get_module(Tree) + ".Tree"]) @@ -137,12 +149,17 @@ def sgd_loss_get_state(obj: 
Any, save_context: SaveContext) -> dict[str, Any]: class SGDNode(ReduceNode): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: # TODO: make sure trusted here makes sense and used. super().__init__( state, load_context, - constructor=gettype(state.get("__module__"), state.get("__class__")), + constructor=gettype(state["__module__"], state["__class__"]), trusted=False, ) self.trusted = self._get_trusted( @@ -173,7 +190,12 @@ def _DictWithDeprecatedKeys_get_state( # TODO: remove once support for sklearn<1.2 is dropped. class _DictWithDeprecatedKeysNode(Node): # _DictWithDeprecatedKeys is just a wrapper for dict - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = [ get_module(_DictWithDeprecatedKeysNode) + "._DictWithDeprecatedKeys" diff --git a/skops/io/_utils.py b/skops/io/_utils.py index b1219233..37b036eb 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from typing import Any +from typing import Any, Type from zipfile import ZipFile @@ -28,7 +28,7 @@ def _getattribute(obj, name): # This function is particularly used to detect the path of functions such as # ufuncs. It returns the full path, instead of returning the module name. 
-def whichmodule(obj, name): +def whichmodule(obj: Any, name: str) -> str: """Find the module an object belong to.""" module_name = getattr(obj, "__module__", None) if module_name is not None: @@ -53,17 +53,18 @@ def whichmodule(obj, name): # --------------------------------------------------------------------- -def _import_obj(module, cls_or_func, package=None): +def _import_obj(module: str, cls_or_func: str, package: str | None = None) -> Any: return getattr(importlib.import_module(module, package=package), cls_or_func) -def gettype(module_name, cls_or_func): +def gettype(module_name: str, cls_or_func: str) -> Type[Any]: if module_name and cls_or_func: return _import_obj(module_name, cls_or_func) - return None + raise ValueError(f"Object {cls_or_func} of module {module_name} is unknown") -def get_module(obj): + +def get_module(obj: Any) -> str: """Find module for given object If the module cannot be identified, it's assumed to be "__main__". @@ -144,14 +145,14 @@ def get_object(self, id: int) -> Any: @singledispatch -def _get_state(obj, save_context): +def _get_state(obj, save_context: SaveContext): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -def get_state(value, save_context): +def get_state(value, save_context: SaveContext) -> dict[str, Any]: # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. 
diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index aff04cff..35902b86 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -11,8 +11,7 @@ from sklearn.preprocessing import FunctionTransformer, StandardScaler from skops.io import dumps, get_untrusted_types -from skops.io._audit import audit_tree, check_type -from skops.io._dispatch import Node, get_tree, temp_setattr +from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr from skops.io._general import DictNode, dict_get_state from skops.io._utils import LoadContext, SaveContext, gettype @@ -102,9 +101,14 @@ def test_list_safety(values, is_safe): assert tree.is_safe() == is_safe -def test_gettype(): - # return None if one argument is None - assert gettype(module_name="test", cls_or_func=None) is None +def test_gettype_error(): + msg = "Object None of module test is unknown" + with pytest.raises(ValueError, match=msg): + gettype(module_name="test", cls_or_func=None) + + msg = "Object test of module None is unknown" + with pytest.raises(ValueError, match=msg): + gettype(module_name=None, cls_or_func="test") # ImportError if the module cannot be imported with pytest.raises(ImportError): @@ -115,9 +119,10 @@ def test_gettype(): "data, file, exception, message", [ ("not-none", "not-none", ValueError, "Only one of data or file"), + (None, None, ValueError, "Exactly one of data or file should be passed"), ("string", None, TypeError, "a bytes-like object is required, not 'str'"), ], - ids=["both", "string-data"], + ids=["both", "neither", "string-data"], ) def test_get_untrusted_types_validation(data, file, exception, message): with pytest.raises(exception, match=message): diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 7c9f5056..f833620a 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -54,7 +54,7 @@ import skops from skops.io import dump, dumps, get_untrusted_types, load, 
loads -from skops.io._dispatch import NODE_TYPE_MAPPING, get_tree +from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES from skops.io._utils import LoadContext, SaveContext, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException