diff --git a/skops/hub_utils/_hf_hub.py b/skops/hub_utils/_hf_hub.py index 39139a15..efbd5c25 100644 --- a/skops/hub_utils/_hf_hub.py +++ b/skops/hub_utils/_hf_hub.py @@ -264,6 +264,7 @@ def _create_config( does not support it. For more info, see https://intel.github.io/scikit-learn-intelex/. """ + # so that we don't have to explicitly add keys and they're added as a # dictionary if they are not found # see: https://stackoverflow.com/a/13151294/2536294 diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 1e379308..35ac2415 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -4,11 +4,12 @@ from contextlib import contextmanager from typing import Any, Generator, Literal, Sequence, Type, Union +from ._protocol import PROTOCOL from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException -NODE_TYPE_MAPPING = {} # type: ignore +NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} def check_type( @@ -311,7 +312,7 @@ def _construct(self): return self.cached.construct() -NODE_TYPE_MAPPING["CachedNode"] = CachedNode +NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: @@ -347,14 +348,27 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: # node's ``construct`` method caches the instance. return load_context.get_object(saved_id) - try: - node_cls = NODE_TYPE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." - ) + loader: str = state["__loader__"] + protocol = load_context.protocol + key = (loader, protocol) + + if key in NODE_TYPE_MAPPING: + node_cls = NODE_TYPE_MAPPING[key] + else: + # What probably happened here is that we released a new protocol. If + # there is no specific key for the old protocol, it means it is safe to + # use the current protocol instead, because this node was not changed. + key_new = (loader, PROTOCOL) + try: + node_cls = NODE_TYPE_MAPPING[key_new] + except KeyError: + # If we still cannot find the loader for this key, something went + # wrong. + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name} and " + f"protocol {protocol}." + ) loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore - return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..189a4d86 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -11,6 +11,7 @@ import numpy as np from ._audit import Node, get_tree +from ._protocol import PROTOCOL from ._trusted_types import ( PRIMITIVE_TYPE_NAMES, SCIPY_UFUNC_TYPE_NAMES, @@ -180,13 +181,9 @@ def isnamedtuple(self, t) -> bool: def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { - "__class__": obj.__class__.__name__, + "__class__": obj.__name__, "__module__": get_module(obj), "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, } return res @@ -201,26 +198,20 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) - self.children = {"content": state["content"]} + self.children = {} def _construct(self): - return _import_obj( - self.children["content"]["module_path"], - self.children["content"]["function"], - ) + return gettype(self.module_name, self.class_name) def _get_function_name(self) -> str: - return ( - self.children["content"]["module_path"] - + "." - + self.children["content"]["function"] - ) + return f"{self.module_name}.{self.class_name}" def get_unsafe_set(self) -> set[str]: - if (self.trusted is True) or (self._get_function_name() in self.trusted): + fn_name = self._get_function_name() + if (self.trusted is True) or (fn_name in self.trusted): return set() - return {self._get_function_name()} + return {fn_name} def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: @@ -586,18 +577,18 @@ def _construct(self): ] NODE_TYPE_MAPPING = { - "DictNode": DictNode, - "ListNode": ListNode, - "SetNode": SetNode, - "TupleNode": TupleNode, - "BytesNode": BytesNode, - "BytearrayNode": BytearrayNode, - "SliceNode": SliceNode, - "FunctionNode": FunctionNode, - "MethodNode": MethodNode, - "PartialNode": PartialNode, - "TypeNode": TypeNode, - "ObjectNode": ObjectNode, - "JsonNode": JsonNode, - "OperatorFuncNode": OperatorFuncNode, + ("DictNode", PROTOCOL): DictNode, + ("ListNode", PROTOCOL): ListNode, + ("SetNode", PROTOCOL): SetNode, + ("TupleNode", PROTOCOL): TupleNode, + ("BytesNode", PROTOCOL): BytesNode, + ("BytearrayNode", PROTOCOL): BytearrayNode, + ("SliceNode", PROTOCOL): SliceNode, + ("FunctionNode", PROTOCOL): FunctionNode, + ("MethodNode", PROTOCOL): MethodNode, + ("PartialNode", PROTOCOL): PartialNode, + ("TypeNode", PROTOCOL): TypeNode, + ("ObjectNode", PROTOCOL): ObjectNode, + ("JsonNode", PROTOCOL): JsonNode, + ("OperatorFuncNode", PROTOCOL): OperatorFuncNode, } diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index e7bc40a9..1f2cef82 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -6,6 +6,8 @@ import numpy as np from ._audit import Node, get_tree +from ._general import function_get_state +from ._protocol import PROTOCOL from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -195,22 +197,6 @@ def _construct(self): return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) -# For numpy.ufunc we need to get the type from the type's module, but for other -# functions we get it from objet's module directly. Therefore sett a especial -# get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - res = { - "__class__": obj.__class__.__name__, # ufunc - "__module__": get_module(type(obj)), # numpy - "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, - } - return res - - def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. @@ -247,16 +233,16 @@ def _construct(self): (np.generic, ndarray_get_state), (np.ndarray, ndarray_get_state), (np.ma.MaskedArray, maskedarray_get_state), - (np.ufunc, ufunc_get_state), + (np.ufunc, function_get_state), (np.dtype, dtype_get_state), (np.random.RandomState, random_state_get_state), (np.random.Generator, random_generator_get_state), ] # tuples of type and function that creates the instance of that type NODE_TYPE_MAPPING = { - "NdArrayNode": NdArrayNode, - "MaskedArrayNode": MaskedArrayNode, - "DTypeNode": DTypeNode, - "RandomStateNode": RandomStateNode, - "RandomGeneratorNode": RandomGeneratorNode, + ("NdArrayNode", PROTOCOL): NdArrayNode, + ("MaskedArrayNode", PROTOCOL): MaskedArrayNode, + ("DTypeNode", PROTOCOL): DTypeNode, + ("RandomStateNode", PROTOCOL): RandomStateNode, + ("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode, } diff --git a/skops/io/_persist.py b/skops/io/_persist.py index ee552127..12796e05 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -13,8 +13,10 @@ from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register -# them. +# them. Old protocols are found in the 'old/' directory, with the protocol +# version appended to the corresponding module name. modules = ["._general", "._numpy", "._scipy", "._sklearn"] +modules.extend([".old._general_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree module = importlib.import_module(module_name, package="skops.io") @@ -123,9 +125,9 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: """ with ZipFile(file, "r") as input_zip: - schema = input_zip.read("schema.json") - load_context = LoadContext(src=input_zip) - tree = get_tree(json.loads(schema), load_context) + schema = json.loads(input_zip.read("schema.json")) + load_context = LoadContext(src=input_zip, protocol=schema["protocol"]) + tree = get_tree(schema, load_context) audit_tree(tree, trusted) instance = tree.construct() @@ -164,7 +166,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - load_context = LoadContext(src=zip_file) + load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) tree = get_tree(schema, load_context) audit_tree(tree, trusted) instance = tree.construct() @@ -208,7 +210,9 @@ def get_untrusted_types( with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + tree = get_tree( + schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"]) + ) untrusted_types = tree.get_unsafe_set() return sorted(untrusted_types) diff --git a/skops/io/_protocol.py b/skops/io/_protocol.py new file mode 100644 index 00000000..cacd4fbe --- /dev/null +++ b/skops/io/_protocol.py @@ -0,0 +1,26 @@ +"""The current protocol of the skops version + +Notes on updating the protocol: + +Every time that a backwards incompatible change to the skops format is made +for the first time within a release, the protocol should be bumped to the next +higher number. The old version of the Node, which knows how to deal with the +old state, should be preserved, registered, and tested. Let's give an example: + +- There is a BC breaking change in FunctionNode. +- Since it's the first BC breaking change in the skops format in this release, + bump skops.io._protocol.PROTOCOL (this file) from version X to X+1. +- Move the old FunctionNode code into 'skops/io/old/_general_vX.py', where 'X' + is the old protocol. +- Register the _general_vX.FunctionNode in NODE_TYPE_MAPPING inside of + _persist.py. +- Write a test in test_persist_old.py that shows that the old state can + still be loaded. Look at test_persist_old.test_function_v0 for inspiration. + +Now, if a user loads a FunctionNode state with version X using skops with +version Y>X, the old code will be used instead of the new one. For all other +node types, if there is no loader for version X, skops will automatically use +version Y instead. + +""" +PROTOCOL = 1 diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index cbf5c1d1..56b3e998 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -6,6 +6,7 @@ from scipy.sparse import load_npz, save_npz, spmatrix from ._audit import Node +from ._protocol import PROTOCOL from ._utils import LoadContext, SaveContext, get_module @@ -65,5 +66,5 @@ def _construct(self): NODE_TYPE_MAPPING = { # use 'spmatrix' to check if a matrix is a sparse matrix because that is # what scipy.sparse.issparse checks - "SparseMatrixNode": SparseMatrixNode, + ("SparseMatrixNode", PROTOCOL): SparseMatrixNode, } diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 14ba4a87..4d302267 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -4,6 +4,8 @@ from sklearn.cluster import Birch +from ._protocol import PROTOCOL + try: # TODO: remove once support for sklearn<1.2 is dropped. See #187 from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys @@ -232,8 +234,8 @@ def _construct(self): # tuples of type and function that creates the instance of that type NODE_TYPE_MAPPING = { - "SGDNode": SGDNode, - "TreeNode": TreeNode, + ("SGDNode", PROTOCOL): SGDNode, + ("TreeNode", PROTOCOL): TreeNode, } # TODO: remove once support for sklearn<1.2 is dropped. @@ -244,5 +246,5 @@ def _construct(self): (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) ) NODE_TYPE_MAPPING[ - "_DictWithDeprecatedKeysNode" + ("_DictWithDeprecatedKeysNode", PROTOCOL) ] = _DictWithDeprecatedKeysNode # type: ignore diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 17eaf8c1..f147d15a 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -7,6 +7,8 @@ from typing import Any, Type from zipfile import ZipFile +from ._protocol import PROTOCOL + # The following two functions are copied from cpython's pickle.py file. # --------------------------------------------------------------------- @@ -83,10 +85,6 @@ def get_module(obj: Any) -> str: return whichmodule(obj, obj.__name__) -# For now, there is just one protocol version -DEFAULT_PROTOCOL = 0 - - @dataclass(frozen=True) class SaveContext: """Context required for saving the objects @@ -105,7 +103,7 @@ class SaveContext: """ zip_file: ZipFile - protocol: int = DEFAULT_PROTOCOL + protocol: int = PROTOCOL memo: dict[int, Any] = field(default_factory=dict) def memoize(self, obj: Any) -> int: @@ -135,6 +133,7 @@ class LoadContext: """ src: ZipFile + protocol: int memo: dict[int, Any] = field(default_factory=dict) def memoize(self, obj: Any, id: int) -> None: diff --git a/skops/io/old/_general_v0.py b/skops/io/old/_general_v0.py new file mode 100644 index 00000000..5580e63e --- /dev/null +++ b/skops/io/old/_general_v0.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Sequence + +from skops.io._audit import Node +from skops.io._trusted_types import SCIPY_UFUNC_TYPE_NAMES +from skops.io._utils import LoadContext, _import_obj + +PROTOCOL = 0 + + +class FunctionNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__(state, load_context, trusted) + # TODO: what do we trust? + self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) + self.children = {"content": state["content"]} + + def _construct(self): + return _import_obj( + self.children["content"]["module_path"], + self.children["content"]["function"], + ) + + def _get_function_name(self) -> str: # pragma: no cover + return ( + self.children["content"]["module_path"] + + "." + + self.children["content"]["function"] + ) + + def get_unsafe_set(self) -> set[str]: # pragma: no cover + if (self.trusted is True) or (self._get_function_name() in self.trusted): + return set() + + return {self._get_function_name()} + + +NODE_TYPE_MAPPING = { + ("FunctionNode", PROTOCOL): FunctionNode, +} diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index 9081a9fc..ce9c45d9 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -1,5 +1,10 @@ +from __future__ import annotations + +import io +import json import sys import warnings +from zipfile import ZipFile import numpy as np from scipy import sparse @@ -173,3 +178,81 @@ def assert_method_outputs_equal(estimator, loaded, X): X_out1 = getattr(estimator, method)(X) X_out2 = getattr(loaded, method)(X) assert_allclose_dense_sparse(X_out1, X_out2, err_msg=err_msg, atol=ATOL) + + +def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: int): + """Function to downgrade the persisted state of a skops object. + + This function is important for testing upgrades to the skops persistence + protocol. When an upgrade is made, we add a test to ensure that the old + state can still be loaded successfully. For this, we need to generate a + state that looks like it came from the previous protocol. This function + helps doing that. + + The caller should pass the new state, the path to the sub-state that needs + to be downgraded, the actual old state to be downgraded to, and the protocol + of that old version. Then this function will replace the new state with the + old state and insert the old protocol number. It also adds an ``__id__`` + field, which is expected for memoization. + + Here is an example of how to use it: + + .. code:: python + + estimator = ... + # get the state of the object using current protocol + dumped = sio.dumps(estimator) + # let's assume that estimator.foo.bar was changed + keys = ["foo", "bar"] + old_state = old_get_state_function(bar) + downgraded = downgrad_state( + data=dumped, + keys=keys, + old_state=old_state + protocol=current_protocol - 1, + ) + # check that this does not raise an error: + sio.loads(downgrade, trusted=...) + + Parameters + ---------- + data : bytes + The old state, as generated by ``skops.io.dumps``. + + keys : list of str + The keys that lead to the old state. E.g. if we want to replace + ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. + + old_state : dict + The old state, as would be produced by the old ``get_state`` function. + + protocol : int + The protocol number corresponding to the old state. + + Returns + ------- + bytes + The old state, as would have been dumped by ``skops.io.dumps``. + + """ + # load from bytes + with ZipFile(io.BytesIO(data), "r") as zip_file: + schema = json.loads(zip_file.read("schema.json")) + + # replace schema + schema["protocol"] = protocol + + # replace state using old state + state = schema + for key in keys[:-1]: + state = state[key] + state[keys[-1]] = old_state + + # there has to be an __id__ field for memoization + state[keys[-1]]["__id__"] = id(schema) + + # dump into bytes + buffer = io.BytesIO() + with ZipFile(buffer, "w") as zip_file: + zip_file.writestr("schema.json", json.dumps(schema, indent=2)) + return buffer.getbuffer().tobytes() diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 71914b4d..d7a9053d 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -41,7 +41,7 @@ def test_check_type(module_name, type_name, trusted, expected): def test_audit_tree_untrusted(): var = {"a": CustomType(1), 2: CustomType(2)} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None), trusted=False) + node = DictNode(state, LoadContext(None, -1), trusted=False) with pytest.raises( TypeError, match=re.escape( @@ -64,7 +64,7 @@ def test_audit_tree_defaults(): # test that the default types are trusted var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None), trusted=False) + node = DictNode(state, LoadContext(None, -1), trusted=False) audit_tree(node, trusted=[]) @@ -97,7 +97,7 @@ def test_list_safety(values, is_safe): with ZipFile(io.BytesIO(content), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + tree = get_tree(schema, load_context=LoadContext(src=zip_file, protocol=-1)) assert tree.is_safe() == is_safe diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 7a6283e2..19a6c7eb 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -523,7 +523,7 @@ def fit(self, X, y=None, **fit_params): schema = json.loads(ZipFile(io.BytesIO(dumped)).read("schema.json")) # check some schema metainfo - assert schema["protocol"] == skops.io._utils.DEFAULT_PROTOCOL + assert schema["protocol"] == skops.io._protocol.PROTOCOL assert schema["_skops_version"] == skops.__version__ # additionally, check following metainfo: class, module, and version @@ -657,7 +657,7 @@ def test_get_tree_unknown_type_error_msg(): state["__loader__"] = "this_get_tree_does_not_exist" msg = "Can't find loader this_get_tree_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_tree(state, LoadContext(None)) + get_tree(state, LoadContext(None, -1)) class _BoundMethodHolder: @@ -953,3 +953,21 @@ def test_persist_operator_raises_untrusted(op): est = FunctionTransformer(func) with pytest.raises(UntrustedTypesFoundException, match=name): loads(dumps(est), trusted=False) + + +def dummy_func(X): + return X + + +@pytest.mark.parametrize("func", [np.sqrt, len, special.exp10, dummy_func]) +def test_persist_function(func): + estimator = FunctionTransformer(func=func) + X, y = [0, 1], [2, 3] + estimator.fit(X, y) + + dumped = dumps(estimator) + loaded = loads(dumped, trusted=True) + + # check that loaded estimator is identical + assert_params_equal(estimator.__dict__, loaded.__dict__) + assert_method_outputs_equal(estimator, loaded, X) diff --git a/skops/io/tests/test_persist_old.py b/skops/io/tests/test_persist_old.py new file mode 100644 index 00000000..69f297fa --- /dev/null +++ b/skops/io/tests/test_persist_old.py @@ -0,0 +1,68 @@ +"""Persistence tests for old versions of the protocol""" + +from __future__ import annotations + +import numpy as np +import pytest +from scipy import special +from sklearn.preprocessing import FunctionTransformer + +from skops.io import dumps, loads +from skops.io._utils import get_module +from skops.io.tests._utils import ( + assert_method_outputs_equal, + assert_params_equal, + downgrade_state, +) + +############# +# VERSION 0 # +############# + + +def dummy_func(X): + return X + + +@pytest.mark.parametrize("func", [np.sqrt, len, special.exp10, dummy_func]) +def test_persist_function_v0(func): + call_count = 0 + + # function_get_state as it was for protocol 0 + def old_function_get_state(obj, save_context): + # added for testing + nonlocal call_count + call_count += 1 + # end + + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(obj), + "__loader__": "FunctionNode", + "content": { + "module_path": get_module(obj), + "function": obj.__name__, + }, + } + return res + + estimator = FunctionTransformer(func=func) + X, y = [0, 1], [2, 3] + estimator.fit(X, y) + + dumped = dumps(estimator) + # importent: downgrade the state to mimic older version + downgraded = downgrade_state( + data=dumped, + keys=["content", "content", "func"], + old_state=old_function_get_state(func, None), + protocol=0, + ) + loaded = loads(downgraded, trusted=True) + + # sanity check: ensure that the old get_state function was really called + assert call_count == 1 + + # check that loaded estimator is identical + assert_params_equal(estimator.__dict__, loaded.__dict__) + assert_method_outputs_equal(estimator, loaded, X)