From 6a36dcc9abe6eb14dfb250a8eb8e95f7139c5a40 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 16 Mar 2023 18:04:58 +0100 Subject: [PATCH 1/8] POC: Proposal for how to deal with protocol update This PR is a POC that is intended as a basis for discussing how we can deal with updating the skops persistence protocol. The problem in general occurs if we make a change to the state that is stored with an object. For instance, we could store an additional field. When we try to load old state without that key, there would be an error. Therefore, we have to somehow adjust our loading code to deal with this possibility. First of all, we want to avoid having a bunch of conditionals in our loading code that checks what protocol is used and then does this or that thing. This approach would become messy and error prone quite quickly, because we intend to support old protocols for a very long time. This proposal here would allow to write a clean implementation of the new deserialization code without regard for backwards compatibility. The old code, however, would not be deleted but instead moved to a different module and registered with its old protocol number. Then, while loading, we would check the protocol stored in the schema and use the old code if there is a match for it. If there is no match, we assume we can safely use the current protocol instead. This is just a draft and I haven't tested it on a real example. If we agree that this is the way forward, I would expand on it and test it properly. But first we should agree that this is the way to go and check if there are no problems with this approach. Below are the instructions with how to deal with a change in protocol: Every time that a backwards incompatible change to the skops format is made for the first time within a release, the protocol should be bumped to the next higher number. The old version of the Node, which knows how to deal with the old state, should be preserved and registered. Let's give an example: - There is a BC breaking change in FunctionNode. - Since it's the first BC breaking change in the skops format in this release, bump skops.io._protocol.PROTOCOL (this file) from version X to Y. - Move the old FunctionNode code into 'skops/io/old/_general_vX.py', where 'X' is the old protocol. - Register the _general_vX.FunctionNode in NODE_TYPE_MAPPING inside of _persist.py. Now, if a user loads a FunctionNode state with version X using skops with version Y, the old code will be used instead of the new one. For all other node types, if there is no loader for version X, skops will automatically use version Y instead. --- skops/io/_audit.py | 34 ++++++++++++++++++++++++---------- skops/io/_general.py | 29 +++++++++++++++-------------- skops/io/_numpy.py | 11 ++++++----- skops/io/_persist.py | 16 ++++++++++------ skops/io/_scipy.py | 3 ++- skops/io/_sklearn.py | 12 +++++++----- skops/io/_utils.py | 9 ++++----- skops/io/tests/test_audit.py | 6 +++--- skops/io/tests/test_persist.py | 4 ++-- 9 files changed, 73 insertions(+), 51 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 1e379308..35ac2415 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -4,11 +4,12 @@ from contextlib import contextmanager from typing import Any, Generator, Literal, Sequence, Type, Union +from ._protocol import PROTOCOL from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException -NODE_TYPE_MAPPING = {} # type: ignore +NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} def check_type( @@ -311,7 +312,7 @@ def _construct(self): return self.cached.construct() -NODE_TYPE_MAPPING["CachedNode"] = CachedNode +NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: @@ -347,14 +348,27 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: # node's ``construct`` method caches the instance. return load_context.get_object(saved_id) - try: - node_cls = NODE_TYPE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." - ) + loader: str = state["__loader__"] + protocol = load_context.protocol + key = (loader, protocol) + + if key in NODE_TYPE_MAPPING: + node_cls = NODE_TYPE_MAPPING[key] + else: + # What probably happened here is that we released a new protocol. If + # there is no specific key for the old protocol, it means it is safe to + # use the current protocol instead, because this node was not changed. + key_new = (loader, PROTOCOL) + try: + node_cls = NODE_TYPE_MAPPING[key_new] + except KeyError: + # If we still cannot find the loader for this key, something went + # wrong. + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name} and " + f"protocol {protocol}." + ) loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore - return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..a265a7bc 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -11,6 +11,7 @@ import numpy as np from ._audit import Node, get_tree +from ._protocol import PROTOCOL from ._trusted_types import ( PRIMITIVE_TYPE_NAMES, SCIPY_UFUNC_TYPE_NAMES, @@ -586,18 +587,18 @@ def _construct(self): ] NODE_TYPE_MAPPING = { - "DictNode": DictNode, - "ListNode": ListNode, - "SetNode": SetNode, - "TupleNode": TupleNode, - "BytesNode": BytesNode, - "BytearrayNode": BytearrayNode, - "SliceNode": SliceNode, - "FunctionNode": FunctionNode, - "MethodNode": MethodNode, - "PartialNode": PartialNode, - "TypeNode": TypeNode, - "ObjectNode": ObjectNode, - "JsonNode": JsonNode, - "OperatorFuncNode": OperatorFuncNode, + ("DictNode", PROTOCOL): DictNode, + ("ListNode", PROTOCOL): ListNode, + ("SetNode", PROTOCOL): SetNode, + ("TupleNode", PROTOCOL): TupleNode, + ("BytesNode", PROTOCOL): BytesNode, + ("BytearrayNode", PROTOCOL): BytearrayNode, + ("SliceNode", PROTOCOL): SliceNode, + ("FunctionNode", PROTOCOL): FunctionNode, + ("MethodNode", PROTOCOL): MethodNode, + ("PartialNode", PROTOCOL): PartialNode, + ("TypeNode", PROTOCOL): TypeNode, + ("ObjectNode", PROTOCOL): ObjectNode, + ("JsonNode", PROTOCOL): JsonNode, + ("OperatorFuncNode", PROTOCOL): OperatorFuncNode, } diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index e7bc40a9..c477f329 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -6,6 +6,7 @@ import numpy as np from ._audit import Node, get_tree +from ._protocol import PROTOCOL from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -254,9 +255,9 @@ def _construct(self): ] # tuples of type and function that creates the instance of that type NODE_TYPE_MAPPING = { - "NdArrayNode": NdArrayNode, - "MaskedArrayNode": MaskedArrayNode, - "DTypeNode": DTypeNode, - "RandomStateNode": RandomStateNode, - "RandomGeneratorNode": RandomGeneratorNode, + ("NdArrayNode", PROTOCOL): NdArrayNode, + ("MaskedArrayNode", PROTOCOL): MaskedArrayNode, + ("DTypeNode", PROTOCOL): DTypeNode, + ("RandomStateNode", PROTOCOL): RandomStateNode, + ("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode, } diff --git a/skops/io/_persist.py b/skops/io/_persist.py index ee552127..12796e05 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -13,8 +13,10 @@ from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register -# them. +# them. Old protocols are found in the 'old/' directory, with the protocol +# version appended to the corresponding module name. modules = ["._general", "._numpy", "._scipy", "._sklearn"] +modules.extend([".old._general_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree module = importlib.import_module(module_name, package="skops.io") @@ -123,9 +125,9 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: """ with ZipFile(file, "r") as input_zip: - schema = input_zip.read("schema.json") - load_context = LoadContext(src=input_zip) - tree = get_tree(json.loads(schema), load_context) + schema = json.loads(input_zip.read("schema.json")) + load_context = LoadContext(src=input_zip, protocol=schema["protocol"]) + tree = get_tree(schema, load_context) audit_tree(tree, trusted) instance = tree.construct() @@ -164,7 +166,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - load_context = LoadContext(src=zip_file) + load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) tree = get_tree(schema, load_context) audit_tree(tree, trusted) instance = tree.construct() @@ -208,7 +210,9 @@ def get_untrusted_types( with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + tree = get_tree( + schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"]) + ) untrusted_types = tree.get_unsafe_set() return sorted(untrusted_types) diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index cbf5c1d1..56b3e998 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -6,6 +6,7 @@ from scipy.sparse import load_npz, save_npz, spmatrix from ._audit import Node +from ._protocol import PROTOCOL from ._utils import LoadContext, SaveContext, get_module @@ -65,5 +66,5 @@ def _construct(self): NODE_TYPE_MAPPING = { # use 'spmatrix' to check if a matrix is a sparse matrix because that is # what scipy.sparse.issparse checks - "SparseMatrixNode": SparseMatrixNode, + ("SparseMatrixNode", PROTOCOL): SparseMatrixNode, } diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 14ba4a87..085a81f3 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -4,6 +4,8 @@ from sklearn.cluster import Birch +from ._protocol import PROTOCOL + try: # TODO: remove once support for sklearn<1.2 is dropped. See #187 from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys @@ -232,8 +234,8 @@ def _construct(self): # tuples of type and function that creates the instance of that type NODE_TYPE_MAPPING = { - "SGDNode": SGDNode, - "TreeNode": TreeNode, + ("SGDNode", PROTOCOL): SGDNode, + ("TreeNode", PROTOCOL): TreeNode, } # TODO: remove once support for sklearn<1.2 is dropped. @@ -243,6 +245,6 @@ def _construct(self): GET_STATE_DISPATCH_FUNCTIONS.append( (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) ) - NODE_TYPE_MAPPING[ - "_DictWithDeprecatedKeysNode" - ] = _DictWithDeprecatedKeysNode # type: ignore + NODE_TYPE_MAPPING[("_DictWithDeprecatedKeysNode", PROTOCOL)] = ( + _DictWithDeprecatedKeysNode # type: ignore + ) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 17eaf8c1..f147d15a 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -7,6 +7,8 @@ from typing import Any, Type from zipfile import ZipFile +from ._protocol import PROTOCOL + # The following two functions are copied from cpython's pickle.py file. # --------------------------------------------------------------------- @@ -83,10 +85,6 @@ def get_module(obj: Any) -> str: return whichmodule(obj, obj.__name__) -# For now, there is just one protocol version -DEFAULT_PROTOCOL = 0 - - @dataclass(frozen=True) class SaveContext: """Context required for saving the objects @@ -105,7 +103,7 @@ class SaveContext: """ zip_file: ZipFile - protocol: int = DEFAULT_PROTOCOL + protocol: int = PROTOCOL memo: dict[int, Any] = field(default_factory=dict) def memoize(self, obj: Any) -> int: @@ -135,6 +133,7 @@ class LoadContext: """ src: ZipFile + protocol: int memo: dict[int, Any] = field(default_factory=dict) def memoize(self, obj: Any, id: int) -> None: diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 71914b4d..d7a9053d 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -41,7 +41,7 @@ def test_check_type(module_name, type_name, trusted, expected): def test_audit_tree_untrusted(): var = {"a": CustomType(1), 2: CustomType(2)} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None), trusted=False) + node = DictNode(state, LoadContext(None, -1), trusted=False) with pytest.raises( TypeError, match=re.escape( @@ -64,7 +64,7 @@ def test_audit_tree_defaults(): # test that the default types are trusted var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None), trusted=False) + node = DictNode(state, LoadContext(None, -1), trusted=False) audit_tree(node, trusted=[]) @@ -97,7 +97,7 @@ def test_list_safety(values, is_safe): with ZipFile(io.BytesIO(content), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + tree = get_tree(schema, load_context=LoadContext(src=zip_file, protocol=-1)) assert tree.is_safe() == is_safe diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 11b16ae3..5145fa1c 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -521,7 +521,7 @@ def fit(self, X, y=None, **fit_params): schema = json.loads(ZipFile(io.BytesIO(dumped)).read("schema.json")) # check some schema metainfo - assert schema["protocol"] == skops.io._utils.DEFAULT_PROTOCOL + assert schema["protocol"] == skops.io._protocol.PROTOCOL assert schema["_skops_version"] == skops.__version__ # additionally, check following metainfo: class, module, and version @@ -655,7 +655,7 @@ def test_get_tree_unknown_type_error_msg(): state["__loader__"] = "this_get_tree_does_not_exist" msg = "Can't find loader this_get_tree_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_tree(state, LoadContext(None)) + get_tree(state, LoadContext(None, -1)) class _BoundMethodHolder: From 35280e73f3b9be47b742f0a490e09e2a1c9fa2d0 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 16 Mar 2023 18:18:27 +0100 Subject: [PATCH 2/8] Fixing black Not sure why pre-commit didn't catch that... --- skops/hub_utils/_hf_hub.py | 1 + skops/io/_sklearn.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/skops/hub_utils/_hf_hub.py b/skops/hub_utils/_hf_hub.py index 39139a15..efbd5c25 100644 --- a/skops/hub_utils/_hf_hub.py +++ b/skops/hub_utils/_hf_hub.py @@ -264,6 +264,7 @@ def _create_config( does not support it. For more info, see https://intel.github.io/scikit-learn-intelex/. """ + # so that we don't have to explicitly add keys and they're added as a # dictionary if they are not found # see: https://stackoverflow.com/a/13151294/2536294 diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 085a81f3..4d302267 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -245,6 +245,6 @@ def _construct(self): GET_STATE_DISPATCH_FUNCTIONS.append( (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) ) - NODE_TYPE_MAPPING[("_DictWithDeprecatedKeysNode", PROTOCOL)] = ( - _DictWithDeprecatedKeysNode # type: ignore - ) + NODE_TYPE_MAPPING[ + ("_DictWithDeprecatedKeysNode", PROTOCOL) + ] = _DictWithDeprecatedKeysNode # type: ignore From b535d9af141da1c477e9753064d825c57f760715 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 16 Mar 2023 18:20:51 +0100 Subject: [PATCH 3/8] Check in missing module file --- skops/io/_protocol.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 skops/io/_protocol.py diff --git a/skops/io/_protocol.py b/skops/io/_protocol.py new file mode 100644 index 00000000..e7dc0a45 --- /dev/null +++ b/skops/io/_protocol.py @@ -0,0 +1,23 @@ +"""The current protocol of the skops version + +Notes on updating the protocol: + +Every time that a backwards incompatible change to the skops format is made +for the first time within a release, the protocol should be bumped to the next +higher number. The old version of the Node, which knows how to deal with the +old state, should be preserved and registered. Let's give an example: + +- There is a BC breaking change in FunctionNode. +- Since it's the first BC breaking change in the skops format in this release, + bump skops.io._protocol.PROTOCOL (this file) from version X to Y. +- Move the old FunctionNode code into 'skops/io/old/_general_vX.py', where 'X' + is the old protocol. +- Register the _general_vX.FunctionNode in NODE_TYPE_MAPPING inside of + _persist.py. + +Now, if a user loads a FunctionNode state with version X using skops with +version Y, the old code will be used instead of the new one. For all other +node types, if there is no loader for version X, skops will automatically use +version Y instead. +""" +PROTOCOL = 1 From b58f169fd7b52ab8f95534f656601c783b77ed5d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 16 Mar 2023 18:24:41 +0100 Subject: [PATCH 4/8] Check in even more missing stuff --- skops/io/old/_general_v0.py | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 skops/io/old/_general_v0.py diff --git a/skops/io/old/_general_v0.py b/skops/io/old/_general_v0.py new file mode 100644 index 00000000..a2eed3a2 --- /dev/null +++ b/skops/io/old/_general_v0.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Sequence + +from skops.io._audit import Node +from skops.io._trusted_types import SCIPY_UFUNC_TYPE_NAMES +from skops.io._utils import LoadContext, _import_obj + +PROTOCOL = 0 + + +class FunctionNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__(state, load_context, trusted) + # TODO: what do we trust? + self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) + self.children = {"content": state["content"]} + + def _construct(self): + return _import_obj( + self.children["content"]["module_path"], + self.children["content"]["function"], + ) + + def _get_function_name(self) -> str: + return ( + self.children["content"]["module_path"] + + "." + + self.children["content"]["function"] + ) + + def get_unsafe_set(self) -> set[str]: + if (self.trusted is True) or (self._get_function_name() in self.trusted): + return set() + + return {self._get_function_name()} + + +NODE_TYPE_MAPPING = { + ("FunctionNode", PROTOCOL): FunctionNode, +} From 88524803c6682cf6c138a4d0dc6e318563c36e10 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 21 Mar 2023 17:53:07 +0100 Subject: [PATCH 5/8] Upgrade protocol to use new FunctionNode Taking the changes from #320 to how functions are persisted and adding them here. Therefore, this PR supersedes #320. That change is added here because it is the perfect test case for the update route of the skops protocol. --- skops/io/_general.py | 24 +++++----------- skops/io/_numpy.py | 19 ++---------- skops/io/_protocol.py | 9 ++++-- skops/io/tests/_utils.py | 62 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 37 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index a265a7bc..189a4d86 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -181,13 +181,9 @@ def isnamedtuple(self, t) -> bool: def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { - "__class__": obj.__class__.__name__, + "__class__": obj.__name__, "__module__": get_module(obj), "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, } return res @@ -202,26 +198,20 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) - self.children = {"content": state["content"]} + self.children = {} def _construct(self): - return _import_obj( - self.children["content"]["module_path"], - self.children["content"]["function"], - ) + return gettype(self.module_name, self.class_name) def _get_function_name(self) -> str: - return ( - self.children["content"]["module_path"] - + "." - + self.children["content"]["function"] - ) + return f"{self.module_name}.{self.class_name}" def get_unsafe_set(self) -> set[str]: - if (self.trusted is True) or (self._get_function_name() in self.trusted): + fn_name = self._get_function_name() + if (self.trusted is True) or (fn_name in self.trusted): return set() - return {self._get_function_name()} + return {fn_name} def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index c477f329..1f2cef82 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -6,6 +6,7 @@ import numpy as np from ._audit import Node, get_tree +from ._general import function_get_state from ._protocol import PROTOCOL from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -196,22 +197,6 @@ def _construct(self): return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) -# For numpy.ufunc we need to get the type from the type's module, but for other -# functions we get it from objet's module directly. Therefore sett a especial -# get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - res = { - "__class__": obj.__class__.__name__, # ufunc - "__module__": get_module(type(obj)), # numpy - "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, - } - return res - - def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. @@ -248,7 +233,7 @@ def _construct(self): (np.generic, ndarray_get_state), (np.ndarray, ndarray_get_state), (np.ma.MaskedArray, maskedarray_get_state), - (np.ufunc, ufunc_get_state), + (np.ufunc, function_get_state), (np.dtype, dtype_get_state), (np.random.RandomState, random_state_get_state), (np.random.Generator, random_generator_get_state), diff --git a/skops/io/_protocol.py b/skops/io/_protocol.py index e7dc0a45..cacd4fbe 100644 --- a/skops/io/_protocol.py +++ b/skops/io/_protocol.py @@ -5,19 +5,22 @@ Every time that a backwards incompatible change to the skops format is made for the first time within a release, the protocol should be bumped to the next higher number. The old version of the Node, which knows how to deal with the -old state, should be preserved and registered. Let's give an example: +old state, should be preserved, registered, and tested. Let's give an example: - There is a BC breaking change in FunctionNode. - Since it's the first BC breaking change in the skops format in this release, - bump skops.io._protocol.PROTOCOL (this file) from version X to Y. + bump skops.io._protocol.PROTOCOL (this file) from version X to X+1. - Move the old FunctionNode code into 'skops/io/old/_general_vX.py', where 'X' is the old protocol. - Register the _general_vX.FunctionNode in NODE_TYPE_MAPPING inside of _persist.py. +- Write a test in test_persist_old.py that shows that the old state can + still be loaded. Look at test_persist_old.test_function_v0 for inspiration. Now, if a user loads a FunctionNode state with version X using skops with -version Y, the old code will be used instead of the new one. For all other +version Y>X, the old code will be used instead of the new one. For all other node types, if there is no loader for version X, skops will automatically use version Y instead. + """ PROTOCOL = 1 diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index 9081a9fc..905367c8 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -1,5 +1,8 @@ +import io +import json import sys import warnings +from zipfile import ZipFile import numpy as np from scipy import sparse @@ -173,3 +176,62 @@ def assert_method_outputs_equal(estimator, loaded, X): X_out1 = getattr(estimator, method)(X) X_out2 = getattr(loaded, method)(X) assert_allclose_dense_sparse(X_out1, X_out2, err_msg=err_msg, atol=ATOL) + + +def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: int): + """Function to downgrade the persistence state to an older version. + + This function is important for testing upgrades to the skops persistence + protocol. When an upgrade is made, we add a test to ensure that the old + state can still be loaded successfully. For this, we need to generate a + state that looks like it came from the previous protocol. This function + helps doing that. + + The caller should pass the new state, the path to the sub-state that needs + to be downgraded, the actual old state to be downgraded to, and the protocol + of that old version. Then this function will replace the new state with the + old state and insert the old protocol number. It also adds an ``__id__`` + field, which is expected for memoization. + + Parameters + ---------- + data : bytes + The old state, as generated by ``skops.io.dumps``. + + keys : list of str + The keys that lead to the old state. E.g. if we want to replace + ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. + + old_state : dict + The old state, as would be produced by the old ``get_state`` function. + + protocol : int + The protocol number corresponding to the old state. + + Returns + ------- + bytes + The old state, as would have been dumped by ``skops.io.dumps``. + + """ + # load from bytes + with ZipFile(io.BytesIO(data), "r") as zip_file: + schema = json.loads(zip_file.read("schema.json")) + + # replace schema + schema["protocol"] = protocol + + # replace state using old state + state = schema + for key in keys[:-1]: + state = state[key] + state[keys[-1]] = old_state + + # there has to be an __id__ field for memoization + state[keys[-1]]["__id__"] = id(schema) + + # dump into bytes + buffer = io.BytesIO() + with ZipFile(buffer, "w") as zip_file: + zip_file.writestr("schema.json", json.dumps(schema, indent=2)) + return buffer.getbuffer().tobytes() From b0f83fac34fee1f8d444496cd49cc9e314bab1c9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 21 Mar 2023 17:59:13 +0100 Subject: [PATCH 6/8] Fix annotation issue, add missing file --- skops/io/tests/_utils.py | 2 + skops/io/tests/test_persist_old.py | 68 ++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 skops/io/tests/test_persist_old.py diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index 905367c8..fa171d32 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import json import sys diff --git a/skops/io/tests/test_persist_old.py b/skops/io/tests/test_persist_old.py new file mode 100644 index 00000000..278423b3 --- /dev/null +++ b/skops/io/tests/test_persist_old.py @@ -0,0 +1,68 @@ +"""Persistence tests for old versions of the protocol""" + +from __future__ import annotations + +import numpy as np +import pytest +from scipy import special +from sklearn.preprocessing import FunctionTransformer + +from skops.io import dumps, loads +from skops.io._utils import get_module +from skops.io.tests._utils import ( + assert_method_outputs_equal, + assert_params_equal, + downgrade_state, +) + +############# +# VERSION 0 # +############# + + +def dummy_func(X): + return X + + +@pytest.mark.parametrize("func", [np.sqrt, len, special.exp10, dummy_func]) +def test_function_v0(func): + call_count = 0 + + # function_get_state as it was for protocol 0 + def old_function_get_state(obj, save_context): + # added for testing + nonlocal call_count + call_count += 1 + # end + + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(obj), + "__loader__": "FunctionNode", + "content": { + "module_path": get_module(obj), + "function": obj.__name__, + }, + } + return res + + estimator = FunctionTransformer(func=func) + X, y = [0, 1], [2, 3] + estimator.fit(X, y) + + dumped = dumps(estimator) + # importent: downgrade the state to mimic older version + downgraded = downgrade_state( + data=dumped, + keys=["content", "content", "func"], + old_state=old_function_get_state(func, None), + protocol=0, + ) + loaded = loads(downgraded, trusted=True) + + # sanity check: ensure that the old get_state function was really called + assert call_count == 1 + + # check that loaded estimator is identical + assert_params_equal(estimator.__dict__, loaded.__dict__) + assert_method_outputs_equal(estimator, loaded, X) From 1e9415e010580a2bada182cb46b2f254466ad526 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 22 Mar 2023 14:59:53 +0100 Subject: [PATCH 7/8] Give an example for downgrade_state Also, add "# pragma: no cover" where it makes sense. --- skops/io/old/_general_v0.py | 4 ++-- skops/io/tests/_utils.py | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/skops/io/old/_general_v0.py b/skops/io/old/_general_v0.py index a2eed3a2..5580e63e 100644 --- a/skops/io/old/_general_v0.py +++ b/skops/io/old/_general_v0.py @@ -27,14 +27,14 @@ def _construct(self): self.children["content"]["function"], ) - def _get_function_name(self) -> str: + def _get_function_name(self) -> str: # pragma: no cover return ( self.children["content"]["module_path"] + "." + self.children["content"]["function"] ) - def get_unsafe_set(self) -> set[str]: + def get_unsafe_set(self) -> set[str]: # pragma: no cover if (self.trusted is True) or (self._get_function_name() in self.trusted): return set() diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index fa171d32..ce9c45d9 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -181,7 +181,7 @@ def assert_method_outputs_equal(estimator, loaded, X): def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: int): - """Function to downgrade the persistence state to an older version. + """Function to downgrade the persisted state of a skops object. This function is important for testing upgrades to the skops persistence protocol. When an upgrade is made, we add a test to ensure that the old @@ -195,6 +195,25 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: old state and insert the old protocol number. It also adds an ``__id__`` field, which is expected for memoization. + Here is an example of how to use it: + + .. code:: python + + estimator = ... + # get the state of the object using current protocol + dumped = sio.dumps(estimator) + # let's assume that estimator.foo.bar was changed + keys = ["foo", "bar"] + old_state = old_get_state_function(bar) + downgraded = downgrad_state( + data=dumped, + keys=keys, + old_state=old_state + protocol=current_protocol - 1, + ) + # check that this does not raise an error: + sio.loads(downgrade, trusted=...) + Parameters ---------- data : bytes From 333f9d11dc1ab620fa684d5823939a7b765a3a45 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 22 Mar 2023 16:47:53 +0100 Subject: [PATCH 8/8] Add test for perstisting functions Test for old and new protocol version are basically identical. --- skops/io/tests/test_persist.py | 18 ++++++++++++++++++ skops/io/tests/test_persist_old.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 5bc9d643..19a6c7eb 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -953,3 +953,21 @@ def test_persist_operator_raises_untrusted(op): est = FunctionTransformer(func) with pytest.raises(UntrustedTypesFoundException, match=name): loads(dumps(est), trusted=False) + + +def dummy_func(X): + return X + + +@pytest.mark.parametrize("func", [np.sqrt, len, special.exp10, dummy_func]) +def test_persist_function(func): + estimator = FunctionTransformer(func=func) + X, y = [0, 1], [2, 3] + estimator.fit(X, y) + + dumped = dumps(estimator) + loaded = loads(dumped, trusted=True) + + # check that loaded estimator is identical + assert_params_equal(estimator.__dict__, loaded.__dict__) + assert_method_outputs_equal(estimator, loaded, X) diff --git a/skops/io/tests/test_persist_old.py b/skops/io/tests/test_persist_old.py index 278423b3..69f297fa 100644 --- a/skops/io/tests/test_persist_old.py +++ b/skops/io/tests/test_persist_old.py @@ -25,7 +25,7 @@ def dummy_func(X): @pytest.mark.parametrize("func", [np.sqrt, len, special.exp10, dummy_func]) -def test_function_v0(func): +def test_persist_function_v0(func): call_count = 0 # function_get_state as it was for protocol 0