From 2c7faf4055e3f8736326bd9ffe66c7960b5fa1cf Mon Sep 17 00:00:00 2001 From: = Date: Fri, 4 Nov 2022 20:35:39 +0000 Subject: [PATCH 01/14] First pass at LoadState with src --- skops/io/_dispatch.py | 18 +++++++++++++-- skops/io/_general.py | 42 +++++++++++++++++----------------- skops/io/_numpy.py | 28 ++++++++++++----------- skops/io/_persist.py | 14 +++++++----- skops/io/_scipy.py | 6 ++--- skops/io/_sklearn.py | 20 ++++++++-------- skops/io/_utils.py | 23 +++++++++++++++++++ skops/io/tests/test_persist.py | 10 ++++---- 8 files changed, 101 insertions(+), 60 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index e0ae9a96..1846625f 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -2,14 +2,21 @@ import json +from skops.io._utils import LoadState + GET_INSTANCE_MAPPING = {} # type: ignore -def get_instance(state, src): +def get_instance(state, load_state: LoadState): """Create instance based on the state, using json if possible""" if state.get("is_json"): return json.loads(state["content"]) + saved_id = state.get("__id__") + if saved_id and saved_id in load_state.memo: + # an instance has already been loaded, just return the loaded instance + return load_state.get_instance(saved_id) + try: get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] except KeyError: @@ -17,4 +24,11 @@ def get_instance(state, src): raise TypeError( f" Can't find loader {state['__loader__']} for type {type_name}." ) - return get_instance_func(state, src) + + loaded_obj = get_instance_func(state, load_state) + + # hold reference to obj in case same instance encountered again in save state + if saved_id: + load_state.memoize(loaded_obj, saved_id) + + return loaded_obj diff --git a/skops/io/_general.py b/skops/io/_general.py index c0594319..418615ca 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,7 +8,7 @@ import numpy as np from ._dispatch import get_instance -from ._utils import SaveState, _import_obj, get_module, get_state, gettype +from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -33,11 +33,11 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dict_get_instance(state, src): +def dict_get_instance(state, load_state: LoadState): content = gettype(state)() - key_types = get_instance(state["key_types"], src) + key_types = get_instance(state["key_types"], load_state) for k_type, item in zip(key_types, state["content"].items()): - content[k_type(item[0])] = get_instance(item[1], src) + content[k_type(item[0])] = get_instance(item[1], load_state) return content @@ -54,10 +54,10 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def list_get_instance(state, src): +def list_get_instance(state, load_state: LoadState): content = gettype(state)() for value in state["content"]: - content.append(get_instance(value, src)) + content.append(get_instance(value, load_state)) return content @@ -72,7 +72,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def tuple_get_instance(state, src): +def tuple_get_instance(state, load_state: LoadState): # Returns a tuple or a namedtuple instance. def isnamedtuple(t): # This is needed since namedtuples need to have the args when @@ -86,7 +86,7 @@ def isnamedtuple(t): return all(type(n) == str for n in f) cls = gettype(state) - content = tuple(get_instance(value, src) for value in state["content"]) + content = tuple(get_instance(value, load_state) for value in state["content"]) if isnamedtuple(cls): return cls(*content) @@ -106,7 +106,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def function_get_instance(state, src): +def function_get_instance(state, load_state: LoadState): loaded = _import_obj(state["content"]["module_path"], state["content"]["function"]) return loaded @@ -127,14 +127,14 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def partial_get_instance(state, src): +def partial_get_instance(state, load_state: LoadState): content = state["content"] - func = get_instance(content["func"], src) - args = get_instance(content["args"], src) - kwds = get_instance(content["kwds"], src) - namespace = get_instance(content["namespace"], src) + func = get_instance(content["func"], load_state) + args = get_instance(content["args"], load_state) + kwds = get_instance(content["kwds"], load_state) + namespace = get_instance(content["namespace"], load_state) instance = partial(func, *args, **kwds) # always use partial, not a subclass - instance.__setstate__((func, args, kwds, namespace)) + instance.__setstate__((func, args, kwds, namespace)) # type: ignore return instance @@ -153,7 +153,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def type_get_instance(state, src): +def type_get_instance(state, load_state: LoadState): loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"]) return loaded @@ -172,7 +172,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def slice_get_instance(state, src): +def slice_get_instance(state, load_state: LoadState): start = state["content"]["start"] stop = state["content"]["stop"] step = state["content"]["step"] @@ -218,7 +218,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def object_get_instance(state, src): +def object_get_instance(state, load_state: LoadState): if state.get("is_json", False): return json.loads(state["content"]) @@ -233,7 +233,7 @@ def object_get_instance(state, src): if not content: # nothing more to do return instance - attrs = get_instance(content, src) + attrs = get_instance(content, load_state) if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) else: @@ -260,8 +260,8 @@ def method_get_state(obj: Any, save_state: SaveState): return res -def method_get_instance(state, src): - loaded_obj = object_get_instance(state["content"]["obj"], src) +def method_get_instance(state, load_state: LoadState): + loaded_obj = object_get_instance(state["content"]["obj"], load_state) method = getattr(loaded_obj, state["content"]["func"]) return method diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 4d1d5b98..2a1ae789 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -7,7 +7,7 @@ from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import SaveState, _import_obj, get_module, get_state +from ._utils import LoadState, SaveState, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException @@ -50,10 +50,12 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def ndarray_get_instance(state, src): +def ndarray_get_instance(state, load_state: LoadState): # Dealing with a regular numpy array, where dtype != object if state["type"] == "numpy": - val = np.load(io.BytesIO(src.read(state["file"])), allow_pickle=False) + val = np.load( + io.BytesIO(load_state.src.read(state["file"])), allow_pickle=False + ) # Coerce type, because it may not be conserved by np.save/load. E.g. a # scalar will be loaded as a 0-dim array. if state["__class__"] != "ndarray": @@ -63,8 +65,8 @@ def ndarray_get_instance(state, src): # We explicitly set the dtype to "O" since we only save object arrays in # json. - shape = get_instance(state["shape"], src) - tmp = [get_instance(s, src) for s in state["content"]] + shape = get_instance(state["shape"], load_state) + tmp = [get_instance(s, load_state) for s in state["content"]] # TODO: this is a hack to get the correct shape of the array. We should # find _a better way_ to do this. if len(shape) == 1: @@ -89,9 +91,9 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def maskedarray_get_instance(state, src): - data = get_instance(state["content"]["data"], src) - mask = get_instance(state["content"]["mask"], src) +def maskedarray_get_instance(state, load_state: LoadState): + data = get_instance(state["content"]["data"], load_state) + mask = get_instance(state["content"]["mask"], load_state) return np.ma.MaskedArray(data, mask) @@ -106,10 +108,10 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def random_state_get_instance(state, src): +def random_state_get_instance(state, load_state: LoadState): cls = _import_obj(state["__module__"], state["__class__"]) random_state = cls() - content = get_instance(state["content"], src) + content = get_instance(state["content"], load_state) random_state.set_state(content) return random_state @@ -125,7 +127,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any return res -def random_generator_get_instance(state, src): +def random_generator_get_instance(state, load_state: LoadState): # first restore the state of the bit generator bit_generator_state = state["content"]["bit_generator"] bit_generator = _import_obj("numpy.random", bit_generator_state["bit_generator"])() @@ -166,10 +168,10 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dtype_get_instance(state, src): +def dtype_get_instance(state, load_state: LoadState): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = ndarray_get_instance(state["content"], src) + tmp = ndarray_get_instance(state["content"], load_state) return tmp.dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 9144da08..7b4cfb65 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -8,7 +8,7 @@ import skops from ._dispatch import GET_INSTANCE_MAPPING, get_instance -from ._utils import SaveState, _get_state, get_state +from ._utils import LoadState, SaveState, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -114,8 +114,9 @@ def load(file): """ with ZipFile(file, "r") as input_zip: - schema = input_zip.read("schema.json") - instance = get_instance(json.loads(schema), input_zip) + schema = json.loads(input_zip.read("schema.json")) + load_state = LoadState(src=input_zip) + instance = get_instance(schema, load_state=load_state) return instance @@ -139,7 +140,8 @@ def loads(data): if isinstance(data, str): raise TypeError("Can't load skops format from string, pass bytes") - with ZipFile(io.BytesIO(data), "r") as zip_file: - schema = json.loads(zip_file.read("schema.json")) - instance = get_instance(schema, src=zip_file) + with ZipFile(io.BytesIO(data), "r") as input_zip: + schema = json.loads(input_zip.read("schema.json")) + load_state = LoadState(src=input_zip) + instance = get_instance(schema, load_state=load_state) return instance diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 305d30dc..166a1717 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -5,7 +5,7 @@ from scipy.sparse import load_npz, save_npz, spmatrix -from ._utils import SaveState, get_module +from ._utils import LoadState, SaveState, get_module def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: @@ -31,7 +31,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def sparse_matrix_get_instance(state, src): +def sparse_matrix_get_instance(state, load_state: LoadState): if state["type"] != "scipy": raise TypeError( f"Cannot load object of type {state['__module__']}.{state['__class__']}" @@ -39,7 +39,7 @@ def sparse_matrix_get_instance(state, src): # scipy load_npz uses numpy.save with allow_pickle=False under the hood, so # we're safe using it - val = load_npz(io.BytesIO(src.read(state["file"]))) + val = load_npz(io.BytesIO(load_state.src.read(state["file"]))) return val diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 6cde5463..89fc47de 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -87,12 +87,12 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def reduce_get_instance(state, src, constructor): +def reduce_get_instance(state, load_state, constructor): reduce = state["__reduce__"] - args = get_instance(reduce["args"], src) + args = get_instance(reduce["args"], load_state) instance = constructor(*args) - attrs = get_instance(state["content"], src) + attrs = get_instance(state["content"], load_state) if not attrs: # nothing more to do return instance @@ -116,8 +116,8 @@ def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def tree_get_instance(state, src): - return reduce_get_instance(state, src, constructor=Tree) +def tree_get_instance(state, load_state): + return reduce_get_instance(state, load_state, constructor=Tree) def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: @@ -126,11 +126,11 @@ def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def sgd_loss_get_instance(state, src): +def sgd_loss_get_instance(state, load_state): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: raise UnsupportedTypeException(f"Expected LossFunction, got {cls}") - return reduce_get_instance(state, src, constructor=cls) + return reduce_get_instance(state, load_state, constructor=cls) # TODO: remove once support for sklearn<1.2 is dropped. @@ -152,11 +152,11 @@ def _DictWithDeprecatedKeys_get_state( # TODO: remove once support for sklearn<1.2 is dropped. -def _DictWithDeprecatedKeys_get_instance(state, src): +def _DictWithDeprecatedKeys_get_instance(state, load_state): # _DictWithDeprecatedKeys is just a wrapper for dict - content = dict_get_instance(state["content"]["main"], src) + content = dict_get_instance(state["content"]["main"], load_state) deprecated_key_to_new_key = dict_get_instance( - state["content"]["_deprecated_key_to_new_key"], src + state["content"]["_deprecated_key_to_new_key"], load_state ) res = _DictWithDeprecatedKeys(**content) res._deprecated_key_to_new_key = deprecated_key_to_new_key diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 2b6f9749..e25678db 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -125,6 +125,29 @@ def clear_memo(self) -> None: self.memo.clear() +@dataclass(frozen=True) +class LoadState: + """State required for loading an object + + This state is passed to each ``get_instance_*`` function. + + Parameters + ---------- + src: zipfile.ZipFile + The zip file the target object is saved in + + """ + + src: ZipFile + memo: dict[int, Any] = field(default_factory=dict) + + def memoize(self, obj: Any, id: int) -> None: + self.memo[id] = obj + + def get_instance(self, id: int) -> Any: + return self.memo.get(id) + + @singledispatch def _get_state(obj, save_state): # This function should never be called directly. Instead, it is used to diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index b999487b..4ff84486 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,7 +56,7 @@ from skops.io import dump, dumps, load, loads from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import _get_state, get_state +from skops.io._utils import LoadState, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -96,16 +96,16 @@ def wrapper(obj, save_state): def debug_get_instance(func): # check consistency of argument names and input type signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["state", "src"] + assert list(signature.parameters.keys()) == ["state", "load_state"] @wraps(func) - def wrapper(state, src): + def wrapper(state, load_state): assert "__class__" in state assert "__module__" in state assert "__loader__" in state - assert isinstance(src, ZipFile) + assert isinstance(load_state, LoadState) - result = func(state, src) + result = func(state, load_state) return result return wrapper From 5109f147b78c651b1a3df3e73c50a64354cb5eae Mon Sep 17 00:00:00 2001 From: = Date: Fri, 4 Nov 2022 21:10:45 +0000 Subject: [PATCH 02/14] Fix old xfail test for bound method --- skops/io/_general.py | 7 +++++-- skops/io/tests/test_persist.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 418615ca..fd725ced 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -98,6 +98,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": obj.__class__.__name__, "__module__": get_module(obj), "__loader__": "function_get_instance", + "__id__": id(obj), "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -192,6 +193,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__loader__": "none", "content": obj_str, "is_json": True, + "__id__": id(obj), } except Exception: pass @@ -200,6 +202,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "object_get_instance", + "__id__": id(obj), } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -251,17 +254,17 @@ def method_get_state(obj: Any, save_state: SaveState): "__class__": obj.__class__.__name__, "__module__": get_module(obj), "__loader__": "method_get_instance", + "__id__": id(obj), "content": { "func": obj.__func__.__name__, "obj": get_state(obj.__self__, save_state), }, } - return res def method_get_instance(state, load_state: LoadState): - loaded_obj = object_get_instance(state["content"]["obj"], load_state) + loaded_obj = get_instance(state["content"]["obj"], load_state) method = getattr(loaded_obj, state["content"]["func"]) return method diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 4ff84486..2dd0aa79 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -867,9 +867,6 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - @pytest.mark.xfail( - reason="Can't load an object as a single instance if referenced multiple times" - ) def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): obj = _BoundMethodHolder(object_state="") From 2854fe878504d1781aa6a9ead2af349178e70bf4 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 7 Nov 2022 20:25:09 +0000 Subject: [PATCH 03/14] Add persist id decorator --- skops/io/_general.py | 21 ++++++++++++++++----- skops/io/_numpy.py | 7 ++++++- skops/io/_persist.py | 3 ++- skops/io/_utils.py | 22 ++++++++++++++++++++++ 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index fd725ced..5d72b897 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,10 +8,19 @@ import numpy as np from ._dispatch import get_instance -from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, gettype +from ._utils import ( + LoadState, + SaveState, + _import_obj, + get_module, + get_state, + gettype, + persist_id, +) from .exceptions import UnsupportedTypeException +@persist_id def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -61,6 +70,7 @@ def list_get_instance(state, load_state: LoadState): return content +@persist_id def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -93,12 +103,12 @@ def isnamedtuple(t): return content +@persist_id def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), "__loader__": "function_get_instance", - "__id__": id(obj), "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -112,6 +122,7 @@ def function_get_instance(state, load_state: LoadState): return loaded +@persist_id def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: _, _, (func, args, kwds, namespace) = obj.__reduce__() res = { @@ -159,6 +170,7 @@ def type_get_instance(state, load_state: LoadState): return loaded +@persist_id def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -180,6 +192,7 @@ def slice_get_instance(state, load_state: LoadState): return slice(start, stop, step) +@persist_id def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # This method is for objects which can either be persisted with json, or # the ones for which we can get/set attributes through @@ -193,7 +206,6 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__loader__": "none", "content": obj_str, "is_json": True, - "__id__": id(obj), } except Exception: pass @@ -202,7 +214,6 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "object_get_instance", - "__id__": id(obj), } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -245,6 +256,7 @@ def object_get_instance(state, load_state: LoadState): return instance +@persist_id def method_get_state(obj: Any, save_state: SaveState): # This method is used to persist bound methods, which are # dependent on a specific instance of an object. @@ -254,7 +266,6 @@ def method_get_state(obj: Any, save_state: SaveState): "__class__": obj.__class__.__name__, "__module__": get_module(obj), "__loader__": "method_get_instance", - "__id__": id(obj), "content": { "func": obj.__func__.__name__, "obj": get_state(obj.__self__, save_state), diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 2a1ae789..2e75aab5 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -7,7 +7,7 @@ from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import LoadState, SaveState, _import_obj, get_module, get_state +from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, persist_id from .exceptions import UnsupportedTypeException @@ -78,6 +78,7 @@ def ndarray_get_instance(state, load_state: LoadState): return val +@persist_id def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -97,6 +98,7 @@ def maskedarray_get_instance(state, load_state: LoadState): return np.ma.MaskedArray(data, mask) +@persist_id def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: content = get_state(obj.get_state(legacy=False), save_state) res = { @@ -116,6 +118,7 @@ def random_state_get_instance(state, load_state: LoadState): return random_state +@persist_id def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: bit_generator_state = obj.bit_generator.state res = { @@ -142,6 +145,7 @@ def random_generator_get_instance(state, load_state: LoadState): # For numpy.ufunc we need to get the type from the type's module, but for other # functions we get it from objet's module directly. Therefore sett a especial # get_state method for them here. The load is the same as other functions. +@persist_id def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc @@ -155,6 +159,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res +@persist_id def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 7b4cfb65..bb4f5103 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -32,7 +32,6 @@ def _save(obj): state["protocol"] = save_state.protocol state["_skops_version"] = skops.__version__ - zip_file.writestr("schema.json", json.dumps(state, indent=2)) return buffer @@ -117,6 +116,7 @@ def load(file): schema = json.loads(input_zip.read("schema.json")) load_state = LoadState(src=input_zip) instance = get_instance(schema, load_state=load_state) + load_state.clear_memo() return instance @@ -144,4 +144,5 @@ def loads(data): schema = json.loads(input_zip.read("schema.json")) load_state = LoadState(src=input_zip) instance = get_instance(schema, load_state=load_state) + load_state.clear_memo() return instance diff --git a/skops/io/_utils.py b/skops/io/_utils.py index e25678db..662ffd31 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -87,6 +87,25 @@ def get_module(obj): DEFAULT_PROTOCOL = 0 +def persist_id(func): + """Wrapper to add __id__ to states we want to be able to persist as single + instances. + + Intended to be used as a decorator. + + NB: Not all get_state functions should include ids. Ephemeral objects + have their IDs reused, and so storing some objects (like dicts, lists, arrays + etc.) can cause problems. + """ + + def wrapper(obj: Any, save_state: SaveState): + result = func(obj, save_state) + result["__id__"] = id(obj) + return result + + return wrapper + + @dataclass(frozen=True) class SaveState: """State required for saving the objects @@ -147,6 +166,9 @@ def memoize(self, obj: Any, id: int) -> None: def get_instance(self, id: int) -> Any: return self.memo.get(id) + def clear_memo(self): + self.memo.clear() + @singledispatch def _get_state(obj, save_state): From 9945c481e5ba57cbc6e042cb8f5250efed565313 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 7 Nov 2022 20:34:56 +0000 Subject: [PATCH 04/14] Update test to fix problem with clashing test --- skops/io/_dispatch.py | 1 + skops/io/_persist.py | 2 -- skops/io/_utils.py | 5 +---- skops/io/tests/test_persist.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index 1846625f..d35d3b09 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -13,6 +13,7 @@ def get_instance(state, load_state: LoadState): return json.loads(state["content"]) saved_id = state.get("__id__") + if saved_id and saved_id in load_state.memo: # an instance has already been loaded, just return the loaded instance return load_state.get_instance(saved_id) diff --git a/skops/io/_persist.py b/skops/io/_persist.py index bb4f5103..c297acda 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -116,7 +116,6 @@ def load(file): schema = json.loads(input_zip.read("schema.json")) load_state = LoadState(src=input_zip) instance = get_instance(schema, load_state=load_state) - load_state.clear_memo() return instance @@ -144,5 +143,4 @@ def loads(data): schema = json.loads(input_zip.read("schema.json")) load_state = LoadState(src=input_zip) instance = get_instance(schema, load_state=load_state) - load_state.clear_memo() return instance diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 662ffd31..b95b0f88 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -94,7 +94,7 @@ def persist_id(func): Intended to be used as a decorator. NB: Not all get_state functions should include ids. Ephemeral objects - have their IDs reused, and so storing some objects (like dicts, lists, arrays + have their IDs reused, and so storing some objects (some dicts, lists, arrays etc.) can cause problems. """ @@ -166,9 +166,6 @@ def memoize(self, obj: Any, id: int) -> None: def get_instance(self, id: int) -> Any: return self.memo.get(id) - def clear_memo(self): - self.memo.clear() - @singledispatch def _get_state(obj, save_state): diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 2dd0aa79..f75b6770 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -789,7 +789,7 @@ def test_get_instance_unknown_type_error_msg(): state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_instance(state, None) + get_instance(state, LoadState(None)) class _BoundMethodHolder: From 4121a35fbff95e2a3c61705efe1a67c95e216fc9 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 18:12:10 +0000 Subject: [PATCH 05/14] Memoize temp objects to avoid ids being reused --- skops/io/_general.py | 17 +---------------- skops/io/_numpy.py | 7 +------ skops/io/_utils.py | 24 ++++-------------------- skops/io/tests/test_persist.py | 6 ++++-- 4 files changed, 10 insertions(+), 44 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 5d72b897..c3a96881 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,19 +8,10 @@ import numpy as np from ._dispatch import get_instance -from ._utils import ( - LoadState, - SaveState, - _import_obj, - get_module, - get_state, - gettype, - persist_id, -) +from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, gettype from .exceptions import UnsupportedTypeException -@persist_id def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -70,7 +61,6 @@ def list_get_instance(state, load_state: LoadState): return content -@persist_id def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -103,7 +93,6 @@ def isnamedtuple(t): return content -@persist_id def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -122,7 +111,6 @@ def function_get_instance(state, load_state: LoadState): return loaded -@persist_id def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: _, _, (func, args, kwds, namespace) = obj.__reduce__() res = { @@ -170,7 +158,6 @@ def type_get_instance(state, load_state: LoadState): return loaded -@persist_id def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -192,7 +179,6 @@ def slice_get_instance(state, load_state: LoadState): return slice(start, stop, step) -@persist_id def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # This method is for objects which can either be persisted with json, or # the ones for which we can get/set attributes through @@ -256,7 +242,6 @@ def object_get_instance(state, load_state: LoadState): return instance -@persist_id def method_get_state(obj: Any, save_state: SaveState): # This method is used to persist bound methods, which are # dependent on a specific instance of an object. diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 2e75aab5..2a1ae789 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -7,7 +7,7 @@ from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, persist_id +from ._utils import LoadState, SaveState, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException @@ -78,7 +78,6 @@ def ndarray_get_instance(state, load_state: LoadState): return val -@persist_id def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -98,7 +97,6 @@ def maskedarray_get_instance(state, load_state: LoadState): return np.ma.MaskedArray(data, mask) -@persist_id def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: content = get_state(obj.get_state(legacy=False), save_state) res = { @@ -118,7 +116,6 @@ def random_state_get_instance(state, load_state: LoadState): return random_state -@persist_id def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: bit_generator_state = obj.bit_generator.state res = { @@ -145,7 +142,6 @@ def random_generator_get_instance(state, load_state: LoadState): # For numpy.ufunc we need to get the type from the type's module, but for other # functions we get it from objet's module directly. Therefore sett a especial # get_state method for them here. The load is the same as other functions. -@persist_id def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc @@ -159,7 +155,6 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -@persist_id def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. diff --git a/skops/io/_utils.py b/skops/io/_utils.py index b95b0f88..70a73757 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -87,25 +87,6 @@ def get_module(obj): DEFAULT_PROTOCOL = 0 -def persist_id(func): - """Wrapper to add __id__ to states we want to be able to persist as single - instances. - - Intended to be used as a decorator. - - NB: Not all get_state functions should include ids. Ephemeral objects - have their IDs reused, and so storing some objects (some dicts, lists, arrays - etc.) can cause problems. - """ - - def wrapper(obj: Any, save_state: SaveState): - result = func(obj, save_state) - result["__id__"] = id(obj) - return result - - return wrapper - - @dataclass(frozen=True) class SaveState: """State required for saving the objects @@ -180,7 +161,10 @@ def get_state(value, save_state): # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. try: - return _get_state(value, save_state) + __id__ = save_state.memoize(obj=value) + res = _get_state(value, save_state) + res["__id__"] = __id__ + return res except TypeError as e1: try: return json.dumps(value) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index f75b6770..94e54984 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,7 +56,7 @@ from skops.io import dump, dumps, load, loads from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import LoadState, _get_state, get_state +from skops.io._utils import LoadState, SaveState, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -785,7 +785,9 @@ def test_loads_from_str(): def test_get_instance_unknown_type_error_msg(): - state = get_state(("hi", [123]), None) + val = ("hi", [123]) + save_state = SaveState(None) + state = get_state(val, save_state) state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): From 13110271bd94b7f20a0e37a7115224a839e83cf1 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 18:39:56 +0000 Subject: [PATCH 06/14] Small test for non-bound-method persistance --- skops/io/_utils.py | 2 +- skops/io/tests/test_persist.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 12e5dc68..0d9ebd21 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -158,8 +158,8 @@ def get_state(value, save_state): # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. try: - __id__ = save_state.memoize(obj=value) res = _get_state(value, save_state) + __id__ = save_state.memoize(obj=value) res["__id__"] = __id__ return res except TypeError as e1: diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 94e54984..fd7549b9 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -959,3 +959,12 @@ def test_disk_and_memory_are_identical(tmp_path): loaded_memory = loads(dumps(estimator)) assert joblib.hash(loaded_disk) == joblib.hash(loaded_memory) + + +def test_when_given_object_referenced_twice_loads_as_one_object(): + some_function = np.array([1, 2]).shape + + transformer = FunctionTransformer(func=some_function, inverse_func=some_function) + loaded_transformer = loads(dumps(transformer)) + + assert loaded_transformer.func is loaded_transformer.inverse_func From d493cd14c581d42266bf4ecf884ec9731a4e280f Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 18:49:56 +0000 Subject: [PATCH 07/14] Add parametrized test for multiple reference object --- skops/io/tests/test_persist.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index fd7549b9..ba1b3472 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -961,10 +961,18 @@ def test_disk_and_memory_are_identical(tmp_path): assert joblib.hash(loaded_disk) == joblib.hash(loaded_memory) -def test_when_given_object_referenced_twice_loads_as_one_object(): - some_function = np.array([1, 2]).shape - - transformer = FunctionTransformer(func=some_function, inverse_func=some_function) - loaded_transformer = loads(dumps(transformer)) +@pytest.mark.parametrize( + "obj", + [ + np.array([1, 2]), + [1, 2, 3], + {1: 1, 2: 2}, + {1, 2, 3}, + np.random.RandomState(42), + ], +) +def test_when_given_object_referenced_twice_loads_as_one_object(obj): + some_thing = {"obj_1": obj, "obj_2": obj} + persisted_thing = loads(dumps(some_thing)) - assert loaded_transformer.func is loaded_transformer.inverse_func + assert persisted_thing["obj_1"] is persisted_thing["obj_2"] From 7bc4f7a63d39106d4d4ace15fbd4a8ae03cee73e Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 18:50:48 +0000 Subject: [PATCH 08/14] Rename objects in test for clarity --- skops/io/tests/test_persist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index ba1b3472..7f6b47ce 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -972,7 +972,7 @@ def test_disk_and_memory_are_identical(tmp_path): ], ) def test_when_given_object_referenced_twice_loads_as_one_object(obj): - some_thing = {"obj_1": obj, "obj_2": obj} - persisted_thing = loads(dumps(some_thing)) + an_object = {"obj_1": obj, "obj_2": obj} + persisted_object = loads(dumps(an_object)) - assert persisted_thing["obj_1"] is persisted_thing["obj_2"] + assert persisted_object["obj_1"] is persisted_object["obj_2"] From a2798aee893f327464f6af6a5440f810638d8926 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 19:04:14 +0000 Subject: [PATCH 09/14] Reorder some parts of logic to persist JSON and string objects as singletons --- skops/io/_dispatch.py | 24 ++++++++++++------------ skops/io/_utils.py | 6 ++++-- skops/io/tests/test_persist.py | 1 + 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index d35d3b09..9d31884d 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -9,24 +9,24 @@ def get_instance(state, load_state: LoadState): """Create instance based on the state, using json if possible""" - if state.get("is_json"): - return json.loads(state["content"]) saved_id = state.get("__id__") - if saved_id and saved_id in load_state.memo: # an instance has already been loaded, just return the loaded instance return load_state.get_instance(saved_id) - try: - get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." - ) - - loaded_obj = get_instance_func(state, load_state) + if state.get("is_json"): + loaded_obj = json.loads(state["content"]) + else: + try: + get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] + except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name}." + ) + + loaded_obj = get_instance_func(state, load_state) # hold reference to obj in case same instance encountered again in save state if saved_id: diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 0d9ebd21..e621e236 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -157,13 +157,15 @@ def get_state(value, save_state): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. + __id__ = save_state.memoize(obj=value) try: res = _get_state(value, save_state) - __id__ = save_state.memoize(obj=value) res["__id__"] = __id__ return res except TypeError as e1: try: - return json.dumps(value) + res = json.dumps(value) + res["__id__"] = __id__ + return res except Exception as e2: raise e1 from e2 diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 7f6b47ce..3f025175 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -968,6 +968,7 @@ def test_disk_and_memory_are_identical(tmp_path): [1, 2, 3], {1: 1, 2: 2}, {1, 2, 3}, + "A string", np.random.RandomState(42), ], ) From d27449c1389462faa3b59b07a8ea835e220c7dd8 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Nov 2022 19:07:12 +0000 Subject: [PATCH 10/14] Reorder try excepts --- skops/io/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index e621e236..20e7be35 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -158,14 +158,14 @@ def get_state(value, save_state): # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. __id__ = save_state.memoize(obj=value) + try: res = _get_state(value, save_state) - res["__id__"] = __id__ - return res except TypeError as e1: try: res = json.dumps(value) - res["__id__"] = __id__ - return res except Exception as e2: raise e1 from e2 + + res["__id__"] = __id__ + return res From 3715d17e30ad5c5ffe69734d787a201431616e6d Mon Sep 17 00:00:00 2001 From: = Date: Wed, 9 Nov 2022 18:13:00 +0000 Subject: [PATCH 11/14] Address PR comments --- skops/io/_dispatch.py | 2 +- skops/io/tests/test_persist.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index 9d31884d..ea2d548d 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -11,7 +11,7 @@ def get_instance(state, load_state: LoadState): """Create instance based on the state, using json if possible""" saved_id = state.get("__id__") - if saved_id and saved_id in load_state.memo: + if saved_id in load_state.memo: # an instance has already been loaded, just return the loaded instance return load_state.get_instance(saved_id) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 3f025175..e1ab8da3 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -88,7 +88,7 @@ def wrapper(obj, save_state): assert "__class__" in result assert "__module__" in result assert "__loader__" in result - + assert "__id__" in result return result return wrapper @@ -103,6 +103,7 @@ def wrapper(state, load_state): assert "__class__" in state assert "__module__" in state assert "__loader__" in state + assert "__id__" in state assert isinstance(load_state, LoadState) result = func(state, load_state) From 459e68ad321c3cc2dfbfa4602ebc08e235aa65fa Mon Sep 17 00:00:00 2001 From: = Date: Wed, 9 Nov 2022 18:30:54 +0000 Subject: [PATCH 12/14] Rename SaveState and LoadState --- skops/io/_dispatch.py | 12 ++--- skops/io/_general.py | 87 ++++++++++++++++++---------------- skops/io/_numpy.py | 56 +++++++++++----------- skops/io/_persist.py | 18 +++---- skops/io/_scipy.py | 14 +++--- skops/io/_sklearn.py | 42 ++++++++-------- skops/io/_utils.py | 12 ++--- skops/io/tests/test_persist.py | 25 +++++----- 8 files changed, 137 insertions(+), 129 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index ea2d548d..02fd3bd2 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -2,18 +2,18 @@ import json -from skops.io._utils import LoadState +from skops.io._utils import LoadContext GET_INSTANCE_MAPPING = {} # type: ignore -def get_instance(state, load_state: LoadState): +def get_instance(state, load_context: LoadContext): """Create instance based on the state, using json if possible""" saved_id = state.get("__id__") - if saved_id in load_state.memo: + if saved_id in load_context.memo: # an instance has already been loaded, just return the loaded instance - return load_state.get_instance(saved_id) + return load_context.get_instance(saved_id) if state.get("is_json"): loaded_obj = json.loads(state["content"]) @@ -26,10 +26,10 @@ def get_instance(state, load_state: LoadState): f" Can't find loader {state['__loader__']} for type {type_name}." ) - loaded_obj = get_instance_func(state, load_state) + loaded_obj = get_instance_func(state, load_context) # hold reference to obj in case same instance encountered again in save state if saved_id: - load_state.memoize(loaded_obj, saved_id) + load_context.memoize(loaded_obj, saved_id) return loaded_obj diff --git a/skops/io/_general.py b/skops/io/_general.py index c3a96881..bf858ad2 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,18 +8,25 @@ import numpy as np from ._dispatch import get_instance -from ._utils import LoadState, SaveState, _import_obj, get_module, get_state, gettype +from ._utils import ( + LoadContext, + SaveContext, + _import_obj, + get_module, + get_state, + gettype, +) from .exceptions import UnsupportedTypeException -def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "dict_get_instance", } - key_types = get_state([type(key) for key in obj.keys()], save_state) + key_types = get_state([type(key) for key in obj.keys()], save_context) content = {} for key, value in obj.items(): if isinstance(value, property): @@ -27,21 +34,21 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: if np.isscalar(key) and hasattr(key, "item"): # convert numpy value to python object key = key.item() # type: ignore - content[key] = get_state(value, save_state) + content[key] = get_state(value, save_context) res["content"] = content res["key_types"] = key_types return res -def dict_get_instance(state, load_state: LoadState): +def dict_get_instance(state, load_context: LoadContext): content = gettype(state)() - key_types = get_instance(state["key_types"], load_state) + key_types = get_instance(state["key_types"], load_context) for k_type, item in zip(key_types, state["content"].items()): - content[k_type(item[0])] = get_instance(item[1], load_state) + content[k_type(item[0])] = get_instance(item[1], load_context) return content -def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -49,30 +56,30 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: } content = [] for value in obj: - content.append(get_state(value, save_state)) + content.append(get_state(value, save_context)) res["content"] = content return res -def list_get_instance(state, load_state: LoadState): +def list_get_instance(state, load_context: LoadContext): content = gettype(state)() for value in state["content"]: - content.append(get_instance(value, load_state)) + content.append(get_instance(value, load_context)) return content -def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def tuple_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "tuple_get_instance", } - content = tuple(get_state(value, save_state) for value in obj) + content = tuple(get_state(value, save_context) for value in obj) res["content"] = content return res -def tuple_get_instance(state, load_state: LoadState): +def tuple_get_instance(state, load_context: LoadContext): # Returns a tuple or a namedtuple instance. def isnamedtuple(t): # This is needed since namedtuples need to have the args when @@ -86,14 +93,14 @@ def isnamedtuple(t): return all(type(n) == str for n in f) cls = gettype(state) - content = tuple(get_instance(value, load_state) for value in state["content"]) + content = tuple(get_instance(value, load_context) for value in state["content"]) if isnamedtuple(cls): return cls(*content) return content -def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), @@ -106,39 +113,39 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def function_get_instance(state, load_state: LoadState): +def function_get_instance(state, load_context: LoadContext): loaded = _import_obj(state["content"]["module_path"], state["content"]["function"]) return loaded -def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: _, _, (func, args, kwds, namespace) = obj.__reduce__() res = { "__class__": "partial", # don't allow any subclass "__module__": get_module(type(obj)), "__loader__": "partial_get_instance", "content": { - "func": get_state(func, save_state), - "args": get_state(args, save_state), - "kwds": get_state(kwds, save_state), - "namespace": get_state(namespace, save_state), + "func": get_state(func, save_context), + "args": get_state(args, save_context), + "kwds": get_state(kwds, save_context), + "namespace": get_state(namespace, save_context), }, } return res -def partial_get_instance(state, load_state: LoadState): +def partial_get_instance(state, load_context: LoadContext): content = state["content"] - func = get_instance(content["func"], load_state) - args = get_instance(content["args"], load_state) - kwds = get_instance(content["kwds"], load_state) - namespace = get_instance(content["namespace"], load_state) + func = get_instance(content["func"], load_context) + args = get_instance(content["args"], load_context) + kwds = get_instance(content["kwds"], load_context) + namespace = get_instance(content["namespace"], load_context) instance = partial(func, *args, **kwds) # always use partial, not a subclass instance.__setstate__((func, args, kwds, namespace)) # type: ignore return instance -def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def type_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # To serialize a type, we first need to set the metadata to tell that it's # a type, then store the type's info itself in the content field. res = { @@ -153,12 +160,12 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def type_get_instance(state, load_state: LoadState): +def type_get_instance(state, load_context: LoadContext): loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"]) return loaded -def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def slice_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -172,14 +179,14 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def slice_get_instance(state, load_state: LoadState): +def slice_get_instance(state, load_context: LoadContext): start = state["content"]["start"] stop = state["content"]["stop"] step = state["content"]["step"] return slice(start, stop, step) -def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is for objects which can either be persisted with json, or # the ones for which we can get/set attributes through # __getstate__/__setstate__ or reading/writing to __dict__. @@ -211,14 +218,14 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: else: return res - content = get_state(attrs, save_state) + content = get_state(attrs, save_context) # it's sufficient to store the "content" because we know that this dict can # only have str type keys res["content"] = content return res -def object_get_instance(state, load_state: LoadState): +def object_get_instance(state, load_context: LoadContext): if state.get("is_json", False): return json.loads(state["content"]) @@ -233,7 +240,7 @@ def object_get_instance(state, load_state: LoadState): if not content: # nothing more to do return instance - attrs = get_instance(content, load_state) + attrs = get_instance(content, load_context) if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) else: @@ -242,7 +249,7 @@ def object_get_instance(state, load_state: LoadState): return instance -def method_get_state(obj: Any, save_state: SaveState): +def method_get_state(obj: Any, save_context: SaveContext): # This method is used to persist bound methods, which are # dependent on a specific instance of an object. # It stores the state of the object the method is bound to, @@ -253,19 +260,19 @@ def method_get_state(obj: Any, save_state: SaveState): "__loader__": "method_get_instance", "content": { "func": obj.__func__.__name__, - "obj": get_state(obj.__self__, save_state), + "obj": get_state(obj.__self__, save_context), }, } return res -def method_get_instance(state, load_state: LoadState): - loaded_obj = get_instance(state["content"]["obj"], load_state) +def method_get_instance(state, load_context: LoadContext): + loaded_obj = get_instance(state["content"]["obj"], load_context) method = getattr(loaded_obj, state["content"]["func"]) return method -def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def unsupported_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: raise UnsupportedTypeException(obj) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 2a1ae789..f8414f10 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -7,11 +7,11 @@ from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import LoadState, SaveState, _import_obj, get_module, get_state +from ._utils import LoadContext, SaveContext, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException -def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def ndarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -23,10 +23,10 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # allow_pickle=False, therefore we convert them to a list and # recursively call get_state on it. if obj.dtype == object: - obj_serialized = get_state(obj.tolist(), save_state) + obj_serialized = get_state(obj.tolist(), save_context) res["content"] = obj_serialized["content"] res["type"] = "json" - res["shape"] = get_state(obj.shape, save_state) + res["shape"] = get_state(obj.shape, save_context) else: data_buffer = io.BytesIO() np.save(data_buffer, obj) @@ -34,10 +34,10 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # the object id) already exists. If it does, there is no need to # save the object again. Memoizitation is necessary since for # ephemeral objects, the same id might otherwise be reused. - obj_id = save_state.memoize(obj) + obj_id = save_context.memoize(obj) f_name = f"{obj_id}.npy" - if f_name not in save_state.zip_file.namelist(): - save_state.zip_file.writestr(f_name, data_buffer.getbuffer()) + if f_name not in save_context.zip_file.namelist(): + save_context.zip_file.writestr(f_name, data_buffer.getbuffer()) res.update(type="numpy", file=f_name) except ValueError: # Couldn't save the numpy array with either method @@ -50,11 +50,11 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def ndarray_get_instance(state, load_state: LoadState): +def ndarray_get_instance(state, load_context: LoadContext): # Dealing with a regular numpy array, where dtype != object if state["type"] == "numpy": val = np.load( - io.BytesIO(load_state.src.read(state["file"])), allow_pickle=False + io.BytesIO(load_context.src.read(state["file"])), allow_pickle=False ) # Coerce type, because it may not be conserved by np.save/load. E.g. a # scalar will be loaded as a 0-dim array. @@ -65,8 +65,8 @@ def ndarray_get_instance(state, load_state: LoadState): # We explicitly set the dtype to "O" since we only save object arrays in # json. - shape = get_instance(state["shape"], load_state) - tmp = [get_instance(s, load_state) for s in state["content"]] + shape = get_instance(state["shape"], load_context) + tmp = [get_instance(s, load_context) for s in state["content"]] # TODO: this is a hack to get the correct shape of the array. We should # find _a better way_ to do this. if len(shape) == 1: @@ -78,27 +78,27 @@ def ndarray_get_instance(state, load_state: LoadState): return val -def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def maskedarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "maskedarray_get_instance", "content": { - "data": get_state(obj.data, save_state), - "mask": get_state(obj.mask, save_state), + "data": get_state(obj.data, save_context), + "mask": get_state(obj.mask, save_context), }, } return res -def maskedarray_get_instance(state, load_state: LoadState): - data = get_instance(state["content"]["data"], load_state) - mask = get_instance(state["content"]["mask"], load_state) +def maskedarray_get_instance(state, load_context: LoadContext): + data = get_instance(state["content"]["data"], load_context) + mask = get_instance(state["content"]["mask"], load_context) return np.ma.MaskedArray(data, mask) -def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - content = get_state(obj.get_state(legacy=False), save_state) +def random_state_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + content = get_state(obj.get_state(legacy=False), save_context) res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -108,15 +108,15 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def random_state_get_instance(state, load_state: LoadState): +def random_state_get_instance(state, load_context: LoadContext): cls = _import_obj(state["__module__"], state["__class__"]) random_state = cls() - content = get_instance(state["content"], load_state) + content = get_instance(state["content"], load_context) random_state.set_state(content) return random_state -def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: bit_generator_state = obj.bit_generator.state res = { "__class__": obj.__class__.__name__, @@ -127,7 +127,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any return res -def random_generator_get_instance(state, load_state: LoadState): +def random_generator_get_instance(state, load_context: LoadContext): # first restore the state of the bit generator bit_generator_state = state["content"]["bit_generator"] bit_generator = _import_obj("numpy.random", bit_generator_state["bit_generator"])() @@ -142,7 +142,7 @@ def random_generator_get_instance(state, load_state: LoadState): # For numpy.ufunc we need to get the type from the type's module, but for other # functions we get it from objet's module directly. Therefore sett a especial # get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc "__module__": get_module(type(obj)), # numpy @@ -155,7 +155,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. tmp: np.typing.NDArray = np.ndarray(0, dtype=obj) @@ -163,15 +163,15 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": "dtype", "__module__": "numpy", "__loader__": "dtype_get_instance", - "content": ndarray_get_state(tmp, save_state), + "content": ndarray_get_state(tmp, save_context), } return res -def dtype_get_instance(state, load_state: LoadState): +def dtype_get_instance(state, load_context: LoadContext): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = ndarray_get_instance(state["content"], load_state) + tmp = ndarray_get_instance(state["content"], load_context) return tmp.dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index c297acda..79d9d34e 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -8,7 +8,7 @@ import skops from ._dispatch import GET_INSTANCE_MAPPING, get_instance -from ._utils import LoadState, SaveState, _get_state, get_state +from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -26,11 +26,11 @@ def _save(obj): buffer = io.BytesIO() with ZipFile(buffer, "w") as zip_file: - save_state = SaveState(zip_file=zip_file) - state = get_state(obj, save_state) - save_state.clear_memo() + save_context = SaveContext(zip_file=zip_file) + state = get_state(obj, save_context) + save_context.clear_memo() - state["protocol"] = save_state.protocol + state["protocol"] = save_context.protocol state["_skops_version"] = skops.__version__ zip_file.writestr("schema.json", json.dumps(state, indent=2)) @@ -114,8 +114,8 @@ def load(file): """ with ZipFile(file, "r") as input_zip: schema = json.loads(input_zip.read("schema.json")) - load_state = LoadState(src=input_zip) - instance = get_instance(schema, load_state=load_state) + load_context = LoadContext(src=input_zip) + instance = get_instance(schema, load_context=load_context) return instance @@ -141,6 +141,6 @@ def loads(data): with ZipFile(io.BytesIO(data), "r") as input_zip: schema = json.loads(input_zip.read("schema.json")) - load_state = LoadState(src=input_zip) - instance = get_instance(schema, load_state=load_state) + load_context = LoadContext(src=input_zip) + instance = get_instance(schema, load_context=load_context) return instance diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 166a1717..95c2202e 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -5,10 +5,10 @@ from scipy.sparse import load_npz, save_npz, spmatrix -from ._utils import LoadState, SaveState, get_module +from ._utils import LoadContext, SaveContext, get_module -def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def sparse_matrix_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -21,17 +21,17 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # the object id) already exists. If it does, there is no need to # save the object again. Memoizitation is necessary since for # ephemeral objects, the same id might otherwise be reused. - obj_id = save_state.memoize(obj) + obj_id = save_context.memoize(obj) f_name = f"{obj_id}.npz" - if f_name not in save_state.zip_file.namelist(): - save_state.zip_file.writestr(f_name, data_buffer.getbuffer()) + if f_name not in save_context.zip_file.namelist(): + save_context.zip_file.writestr(f_name, data_buffer.getbuffer()) res["type"] = "scipy" res["file"] = f_name return res -def sparse_matrix_get_instance(state, load_state: LoadState): +def sparse_matrix_get_instance(state, load_context: LoadContext): if state["type"] != "scipy": raise TypeError( f"Cannot load object of type {state['__module__']}.{state['__class__']}" @@ -39,7 +39,7 @@ def sparse_matrix_get_instance(state, load_state: LoadState): # scipy load_npz uses numpy.save with allow_pickle=False under the hood, so # we're safe using it - val = load_npz(io.BytesIO(load_state.src.read(state["file"]))) + val = load_npz(io.BytesIO(load_context.src.read(state["file"]))) return val diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 89fc47de..f135cf8b 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -24,7 +24,7 @@ from ._dispatch import get_instance from ._general import dict_get_instance, dict_get_state, unsupported_get_state -from ._utils import SaveState, get_module, get_state, gettype +from ._utils import SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException ALLOWED_SGD_LOSSES = { @@ -41,7 +41,7 @@ UNSUPPORTED_TYPES = {Birch} -def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is for objects for which we have to use the __reduce__ # method to get the state. res = { @@ -65,7 +65,7 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # As a good example, this makes Tree object to be serializable. reduce = obj.__reduce__() res["__reduce__"] = {} - res["__reduce__"]["args"] = get_state(reduce[1], save_state) + res["__reduce__"]["args"] = get_state(reduce[1], save_context) if len(reduce) == 3: # reduce includes what's needed for __getstate__ and we don't need to @@ -83,16 +83,16 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: f"Objects of type {res['__class__']} not supported yet" ) - res["content"] = get_state(attrs, save_state) + res["content"] = get_state(attrs, save_context) return res -def reduce_get_instance(state, load_state, constructor): +def reduce_get_instance(state, load_context, constructor): reduce = state["__reduce__"] - args = get_instance(reduce["args"], load_state) + args = get_instance(reduce["args"], load_context) instance = constructor(*args) - attrs = get_instance(state["content"], load_state) + attrs = get_instance(state["content"], load_context) if not attrs: # nothing more to do return instance @@ -110,32 +110,32 @@ def reduce_get_instance(state, load_state, constructor): return instance -def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - state = reduce_get_state(obj, save_state) +def tree_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) state["__loader__"] = "tree_get_instance" return state -def tree_get_instance(state, load_state): - return reduce_get_instance(state, load_state, constructor=Tree) +def tree_get_instance(state, load_context): + return reduce_get_instance(state, load_context, constructor=Tree) -def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - state = reduce_get_state(obj, save_state) +def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) state["__loader__"] = "sgd_loss_get_instance" return state -def sgd_loss_get_instance(state, load_state): +def sgd_loss_get_instance(state, load_context): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: raise UnsupportedTypeException(f"Expected LossFunction, got {cls}") - return reduce_get_instance(state, load_state, constructor=cls) + return reduce_get_instance(state, load_context, constructor=cls) # TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_state( - obj: Any, save_state: SaveState + obj: Any, save_context: SaveContext ) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -143,20 +143,20 @@ def _DictWithDeprecatedKeys_get_state( "__loader__": "_DictWithDeprecatedKeys_get_instance", } content = {} - content["main"] = dict_get_state(obj, save_state) + content["main"] = dict_get_state(obj, save_context) content["_deprecated_key_to_new_key"] = dict_get_state( - obj._deprecated_key_to_new_key, save_state + obj._deprecated_key_to_new_key, save_context ) res["content"] = content return res # TODO: remove once support for sklearn<1.2 is dropped. -def _DictWithDeprecatedKeys_get_instance(state, load_state): +def _DictWithDeprecatedKeys_get_instance(state, load_context): # _DictWithDeprecatedKeys is just a wrapper for dict - content = dict_get_instance(state["content"]["main"], load_state) + content = dict_get_instance(state["content"]["main"], load_context) deprecated_key_to_new_key = dict_get_instance( - state["content"]["_deprecated_key_to_new_key"], load_state + state["content"]["_deprecated_key_to_new_key"], load_context ) res = _DictWithDeprecatedKeys(**content) res._deprecated_key_to_new_key = deprecated_key_to_new_key diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 20e7be35..70d2126c 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -88,7 +88,7 @@ def get_module(obj): @dataclass(frozen=True) -class SaveState: +class SaveContext: """State required for saving the objects This state is passed to each ``get_state_*`` function. @@ -123,7 +123,7 @@ def clear_memo(self) -> None: @dataclass(frozen=True) -class LoadState: +class LoadContext: """State required for loading an object This state is passed to each ``get_instance_*`` function. @@ -146,21 +146,21 @@ def get_instance(self, id: int) -> Any: @singledispatch -def _get_state(obj, save_state): +def _get_state(obj, save_context): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -def get_state(value, save_state): +def get_state(value, save_context): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. - __id__ = save_state.memoize(obj=value) + __id__ = save_context.memoize(obj=value) try: - res = _get_state(value, save_state) + res = _get_state(value, save_context) except TypeError as e1: try: res = json.dumps(value) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index e1ab8da3..0bbd7325 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,7 +56,7 @@ from skops.io import dump, dumps, load, loads from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import LoadState, SaveState, _get_state, get_state +from skops.io._utils import LoadContext, SaveContext, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -79,16 +79,17 @@ def debug_get_state(func): # Check consistency of argument names, output type, and that the output, # if a dict, has certain keys, or if not a dict, is a primitive type. signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["obj", "save_state"] + assert list(signature.parameters.keys()) == ["obj", "save_context"] @wraps(func) - def wrapper(obj, save_state): - result = func(obj, save_state) + def wrapper(obj, save_context): + # NB: __id__ set in main 'get_state' func, so no check here + result = func(obj, save_context) assert "__class__" in result assert "__module__" in result assert "__loader__" in result - assert "__id__" in result + return result return wrapper @@ -96,17 +97,17 @@ def wrapper(obj, save_state): def debug_get_instance(func): # check consistency of argument names and input type signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["state", "load_state"] + assert list(signature.parameters.keys()) == ["state", "load_context"] @wraps(func) - def wrapper(state, load_state): + def wrapper(state, load_context): assert "__class__" in state assert "__module__" in state assert "__loader__" in state assert "__id__" in state - assert isinstance(load_state, LoadState) + assert isinstance(load_context, LoadContext) - result = func(state, load_state) + result = func(state, load_context) return result return wrapper @@ -787,12 +788,12 @@ def test_loads_from_str(): def test_get_instance_unknown_type_error_msg(): val = ("hi", [123]) - save_state = SaveState(None) - state = get_state(val, save_state) + save_context = SaveContext(None) + state = get_state(val, save_context) state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_instance(state, LoadState(None)) + get_instance(state, LoadContext(None)) class _BoundMethodHolder: From df3848a7dfac3414e742edf2ea3046de456ed162 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 9 Nov 2022 18:37:40 +0000 Subject: [PATCH 13/14] Update docstrings to use context, not state --- skops/io/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 70d2126c..f506bb49 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -89,9 +89,9 @@ def get_module(obj): @dataclass(frozen=True) class SaveContext: - """State required for saving the objects + """Context required for saving the objects - This state is passed to each ``get_state_*`` function. + This context is passed to each ``get_state_*`` function. Parameters ---------- @@ -124,9 +124,9 @@ def clear_memo(self) -> None: @dataclass(frozen=True) class LoadContext: - """State required for loading an object + """Context required for loading an object - This state is passed to each ``get_instance_*`` function. + This context is passed to each ``get_instance_*`` function. Parameters ---------- From 5566f0db2c46e27bfe6da030eb8f6d1124330394 Mon Sep 17 00:00:00 2001 From: Erik Aho Date: Fri, 11 Nov 2022 19:39:29 +0000 Subject: [PATCH 14/14] Remove newline in docstring --- skops/io/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index f506bb49..fa46fd75 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -132,7 +132,6 @@ class LoadContext: ---------- src: zipfile.ZipFile The zip file the target object is saved in - """ src: ZipFile