diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index e0ae9a96..02fd3bd2 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -2,19 +2,34 @@ import json +from skops.io._utils import LoadContext + GET_INSTANCE_MAPPING = {} # type: ignore -def get_instance(state, src): +def get_instance(state, load_context: LoadContext): """Create instance based on the state, using json if possible""" + + saved_id = state.get("__id__") + if saved_id in load_context.memo: + # an instance has already been loaded, just return the loaded instance + return load_context.get_instance(saved_id) + if state.get("is_json"): - return json.loads(state["content"]) - - try: - get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." - ) - return get_instance_func(state, src) + loaded_obj = json.loads(state["content"]) + else: + try: + get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] + except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name}." 
+ ) + + loaded_obj = get_instance_func(state, load_context) + + # hold reference to obj in case same instance encountered again in save state + if saved_id is not None: + load_context.memoize(loaded_obj, saved_id) + + return loaded_obj diff --git a/skops/io/_general.py b/skops/io/_general.py index c0594319..bf858ad2 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,18 +8,25 @@ import numpy as np from ._dispatch import get_instance -from ._utils import SaveState, _import_obj, get_module, get_state, gettype +from ._utils import ( + LoadContext, + SaveContext, + _import_obj, + get_module, + get_state, + gettype, +) from .exceptions import UnsupportedTypeException -def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "dict_get_instance", } - key_types = get_state([type(key) for key in obj.keys()], save_state) + key_types = get_state([type(key) for key in obj.keys()], save_context) content = {} for key, value in obj.items(): if isinstance(value, property): @@ -27,21 +34,21 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: if np.isscalar(key) and hasattr(key, "item"): # convert numpy value to python object key = key.item() # type: ignore - content[key] = get_state(value, save_state) + content[key] = get_state(value, save_context) res["content"] = content res["key_types"] = key_types return res -def dict_get_instance(state, src): +def dict_get_instance(state, load_context: LoadContext): content = gettype(state)() - key_types = get_instance(state["key_types"], src) + key_types = get_instance(state["key_types"], load_context) for k_type, item in zip(key_types, state["content"].items()): - content[k_type(item[0])] = get_instance(item[1], src) + content[k_type(item[0])] = get_instance(item[1], load_context) return content -def list_get_state(obj: Any, 
save_state: SaveState) -> dict[str, Any]: +def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -49,30 +56,30 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: } content = [] for value in obj: - content.append(get_state(value, save_state)) + content.append(get_state(value, save_context)) res["content"] = content return res -def list_get_instance(state, src): +def list_get_instance(state, load_context: LoadContext): content = gettype(state)() for value in state["content"]: - content.append(get_instance(value, src)) + content.append(get_instance(value, load_context)) return content -def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def tuple_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "tuple_get_instance", } - content = tuple(get_state(value, save_state) for value in obj) + content = tuple(get_state(value, save_context) for value in obj) res["content"] = content return res -def tuple_get_instance(state, src): +def tuple_get_instance(state, load_context: LoadContext): # Returns a tuple or a namedtuple instance. 
def isnamedtuple(t): # This is needed since namedtuples need to have the args when @@ -86,14 +93,14 @@ def isnamedtuple(t): return all(type(n) == str for n in f) cls = gettype(state) - content = tuple(get_instance(value, src) for value in state["content"]) + content = tuple(get_instance(value, load_context) for value in state["content"]) if isnamedtuple(cls): return cls(*content) return content -def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), @@ -106,39 +113,39 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def function_get_instance(state, src): +def function_get_instance(state, load_context: LoadContext): loaded = _import_obj(state["content"]["module_path"], state["content"]["function"]) return loaded -def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: _, _, (func, args, kwds, namespace) = obj.__reduce__() res = { "__class__": "partial", # don't allow any subclass "__module__": get_module(type(obj)), "__loader__": "partial_get_instance", "content": { - "func": get_state(func, save_state), - "args": get_state(args, save_state), - "kwds": get_state(kwds, save_state), - "namespace": get_state(namespace, save_state), + "func": get_state(func, save_context), + "args": get_state(args, save_context), + "kwds": get_state(kwds, save_context), + "namespace": get_state(namespace, save_context), }, } return res -def partial_get_instance(state, src): +def partial_get_instance(state, load_context: LoadContext): content = state["content"] - func = get_instance(content["func"], src) - args = get_instance(content["args"], src) - kwds = get_instance(content["kwds"], src) - namespace = get_instance(content["namespace"], src) + func = 
get_instance(content["func"], load_context) + args = get_instance(content["args"], load_context) + kwds = get_instance(content["kwds"], load_context) + namespace = get_instance(content["namespace"], load_context) instance = partial(func, *args, **kwds) # always use partial, not a subclass - instance.__setstate__((func, args, kwds, namespace)) + instance.__setstate__((func, args, kwds, namespace)) # type: ignore return instance -def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def type_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # To serialize a type, we first need to set the metadata to tell that it's # a type, then store the type's info itself in the content field. res = { @@ -153,12 +160,12 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def type_get_instance(state, src): +def type_get_instance(state, load_context: LoadContext): loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"]) return loaded -def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def slice_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -172,14 +179,14 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def slice_get_instance(state, src): +def slice_get_instance(state, load_context: LoadContext): start = state["content"]["start"] stop = state["content"]["stop"] step = state["content"]["step"] return slice(start, stop, step) -def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is for objects which can either be persisted with json, or # the ones for which we can get/set attributes through # __getstate__/__setstate__ or reading/writing to __dict__. 
@@ -211,14 +218,14 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: else: return res - content = get_state(attrs, save_state) + content = get_state(attrs, save_context) # it's sufficient to store the "content" because we know that this dict can # only have str type keys res["content"] = content return res -def object_get_instance(state, src): +def object_get_instance(state, load_context: LoadContext): if state.get("is_json", False): return json.loads(state["content"]) @@ -233,7 +240,7 @@ def object_get_instance(state, src): if not content: # nothing more to do return instance - attrs = get_instance(content, src) + attrs = get_instance(content, load_context) if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) else: @@ -242,7 +249,7 @@ def object_get_instance(state, src): return instance -def method_get_state(obj: Any, save_state: SaveState): +def method_get_state(obj: Any, save_context: SaveContext): # This method is used to persist bound methods, which are # dependent on a specific instance of an object. 
# It stores the state of the object the method is bound to, @@ -253,20 +260,19 @@ def method_get_state(obj: Any, save_state: SaveState): "__loader__": "method_get_instance", "content": { "func": obj.__func__.__name__, - "obj": get_state(obj.__self__, save_state), + "obj": get_state(obj.__self__, save_context), }, } - return res -def method_get_instance(state, src): - loaded_obj = object_get_instance(state["content"]["obj"], src) +def method_get_instance(state, load_context: LoadContext): + loaded_obj = get_instance(state["content"]["obj"], load_context) method = getattr(loaded_obj, state["content"]["func"]) return method -def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def unsupported_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: raise UnsupportedTypeException(obj) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 4d1d5b98..f8414f10 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -7,11 +7,11 @@ from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import SaveState, _import_obj, get_module, get_state +from ._utils import LoadContext, SaveContext, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException -def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def ndarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -23,10 +23,10 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # allow_pickle=False, therefore we convert them to a list and # recursively call get_state on it. 
if obj.dtype == object: - obj_serialized = get_state(obj.tolist(), save_state) + obj_serialized = get_state(obj.tolist(), save_context) res["content"] = obj_serialized["content"] res["type"] = "json" - res["shape"] = get_state(obj.shape, save_state) + res["shape"] = get_state(obj.shape, save_context) else: data_buffer = io.BytesIO() np.save(data_buffer, obj) @@ -34,10 +34,10 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # the object id) already exists. If it does, there is no need to # save the object again. Memoizitation is necessary since for # ephemeral objects, the same id might otherwise be reused. - obj_id = save_state.memoize(obj) + obj_id = save_context.memoize(obj) f_name = f"{obj_id}.npy" - if f_name not in save_state.zip_file.namelist(): - save_state.zip_file.writestr(f_name, data_buffer.getbuffer()) + if f_name not in save_context.zip_file.namelist(): + save_context.zip_file.writestr(f_name, data_buffer.getbuffer()) res.update(type="numpy", file=f_name) except ValueError: # Couldn't save the numpy array with either method @@ -50,10 +50,12 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def ndarray_get_instance(state, src): +def ndarray_get_instance(state, load_context: LoadContext): # Dealing with a regular numpy array, where dtype != object if state["type"] == "numpy": - val = np.load(io.BytesIO(src.read(state["file"])), allow_pickle=False) + val = np.load( + io.BytesIO(load_context.src.read(state["file"])), allow_pickle=False + ) # Coerce type, because it may not be conserved by np.save/load. E.g. a # scalar will be loaded as a 0-dim array. if state["__class__"] != "ndarray": @@ -63,8 +65,8 @@ def ndarray_get_instance(state, src): # We explicitly set the dtype to "O" since we only save object arrays in # json. 
- shape = get_instance(state["shape"], src) - tmp = [get_instance(s, src) for s in state["content"]] + shape = get_instance(state["shape"], load_context) + tmp = [get_instance(s, load_context) for s in state["content"]] # TODO: this is a hack to get the correct shape of the array. We should # find _a better way_ to do this. if len(shape) == 1: @@ -76,27 +78,27 @@ def ndarray_get_instance(state, src): return val -def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def maskedarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "maskedarray_get_instance", "content": { - "data": get_state(obj.data, save_state), - "mask": get_state(obj.mask, save_state), + "data": get_state(obj.data, save_context), + "mask": get_state(obj.mask, save_context), }, } return res -def maskedarray_get_instance(state, src): - data = get_instance(state["content"]["data"], src) - mask = get_instance(state["content"]["mask"], src) +def maskedarray_get_instance(state, load_context: LoadContext): + data = get_instance(state["content"]["data"], load_context) + mask = get_instance(state["content"]["mask"], load_context) return np.ma.MaskedArray(data, mask) -def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - content = get_state(obj.get_state(legacy=False), save_state) +def random_state_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + content = get_state(obj.get_state(legacy=False), save_context) res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -106,15 +108,15 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def random_state_get_instance(state, src): +def random_state_get_instance(state, load_context: LoadContext): cls = _import_obj(state["__module__"], state["__class__"]) random_state = cls() - content = get_instance(state["content"], 
src) + content = get_instance(state["content"], load_context) random_state.set_state(content) return random_state -def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: bit_generator_state = obj.bit_generator.state res = { "__class__": obj.__class__.__name__, @@ -125,7 +127,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any return res -def random_generator_get_instance(state, src): +def random_generator_get_instance(state, load_context: LoadContext): # first restore the state of the bit generator bit_generator_state = state["content"]["bit_generator"] bit_generator = _import_obj("numpy.random", bit_generator_state["bit_generator"])() @@ -140,7 +142,7 @@ def random_generator_get_instance(state, src): # For numpy.ufunc we need to get the type from the type's module, but for other # functions we get it from objet's module directly. Therefore sett a especial # get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc "__module__": get_module(type(obj)), # numpy @@ -153,7 +155,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. 
tmp: np.typing.NDArray = np.ndarray(0, dtype=obj) @@ -161,15 +163,15 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": "dtype", "__module__": "numpy", "__loader__": "dtype_get_instance", - "content": ndarray_get_state(tmp, save_state), + "content": ndarray_get_state(tmp, save_context), } return res -def dtype_get_instance(state, src): +def dtype_get_instance(state, load_context: LoadContext): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = ndarray_get_instance(state["content"], src) + tmp = ndarray_get_instance(state["content"], load_context) return tmp.dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 9144da08..79d9d34e 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -8,7 +8,7 @@ import skops from ._dispatch import GET_INSTANCE_MAPPING, get_instance -from ._utils import SaveState, _get_state, get_state +from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. 
@@ -26,13 +26,12 @@ def _save(obj): buffer = io.BytesIO() with ZipFile(buffer, "w") as zip_file: - save_state = SaveState(zip_file=zip_file) - state = get_state(obj, save_state) - save_state.clear_memo() + save_context = SaveContext(zip_file=zip_file) + state = get_state(obj, save_context) + save_context.clear_memo() - state["protocol"] = save_state.protocol + state["protocol"] = save_context.protocol state["_skops_version"] = skops.__version__ - zip_file.writestr("schema.json", json.dumps(state, indent=2)) return buffer @@ -114,8 +113,9 @@ def load(file): """ with ZipFile(file, "r") as input_zip: - schema = input_zip.read("schema.json") - instance = get_instance(json.loads(schema), input_zip) + schema = json.loads(input_zip.read("schema.json")) + load_context = LoadContext(src=input_zip) + instance = get_instance(schema, load_context=load_context) return instance @@ -139,7 +139,8 @@ def loads(data): if isinstance(data, str): raise TypeError("Can't load skops format from string, pass bytes") - with ZipFile(io.BytesIO(data), "r") as zip_file: - schema = json.loads(zip_file.read("schema.json")) - instance = get_instance(schema, src=zip_file) + with ZipFile(io.BytesIO(data), "r") as input_zip: + schema = json.loads(input_zip.read("schema.json")) + load_context = LoadContext(src=input_zip) + instance = get_instance(schema, load_context=load_context) return instance diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 305d30dc..95c2202e 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -5,10 +5,10 @@ from scipy.sparse import load_npz, save_npz, spmatrix -from ._utils import SaveState, get_module +from ._utils import LoadContext, SaveContext, get_module -def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def sparse_matrix_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -21,17 +21,17 @@ def sparse_matrix_get_state(obj: Any, 
save_state: SaveState) -> dict[str, Any]: # the object id) already exists. If it does, there is no need to # save the object again. Memoizitation is necessary since for # ephemeral objects, the same id might otherwise be reused. - obj_id = save_state.memoize(obj) + obj_id = save_context.memoize(obj) f_name = f"{obj_id}.npz" - if f_name not in save_state.zip_file.namelist(): - save_state.zip_file.writestr(f_name, data_buffer.getbuffer()) + if f_name not in save_context.zip_file.namelist(): + save_context.zip_file.writestr(f_name, data_buffer.getbuffer()) res["type"] = "scipy" res["file"] = f_name return res -def sparse_matrix_get_instance(state, src): +def sparse_matrix_get_instance(state, load_context: LoadContext): if state["type"] != "scipy": raise TypeError( f"Cannot load object of type {state['__module__']}.{state['__class__']}" @@ -39,7 +39,7 @@ def sparse_matrix_get_instance(state, src): # scipy load_npz uses numpy.save with allow_pickle=False under the hood, so # we're safe using it - val = load_npz(io.BytesIO(src.read(state["file"]))) + val = load_npz(io.BytesIO(load_context.src.read(state["file"]))) return val diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 6cde5463..f135cf8b 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -24,7 +24,7 @@ from ._dispatch import get_instance from ._general import dict_get_instance, dict_get_state, unsupported_get_state -from ._utils import SaveState, get_module, get_state, gettype +from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException ALLOWED_SGD_LOSSES = { @@ -41,7 +41,7 @@ UNSUPPORTED_TYPES = {Birch} -def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: +def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is for objects for which we have to use the __reduce__ # method to get the state. 
res = { @@ -65,7 +65,7 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: # As a good example, this makes Tree object to be serializable. reduce = obj.__reduce__() res["__reduce__"] = {} - res["__reduce__"]["args"] = get_state(reduce[1], save_state) + res["__reduce__"]["args"] = get_state(reduce[1], save_context) if len(reduce) == 3: # reduce includes what's needed for __getstate__ and we don't need to @@ -83,16 +83,16 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: f"Objects of type {res['__class__']} not supported yet" ) - res["content"] = get_state(attrs, save_state) + res["content"] = get_state(attrs, save_context) return res -def reduce_get_instance(state, src, constructor): +def reduce_get_instance(state, load_context: LoadContext, constructor): reduce = state["__reduce__"] - args = get_instance(reduce["args"], src) + args = get_instance(reduce["args"], load_context) instance = constructor(*args) - attrs = get_instance(state["content"], src) + attrs = get_instance(state["content"], load_context) if not attrs: # nothing more to do return instance @@ -110,32 +110,32 @@ def reduce_get_instance(state, src, constructor): return instance -def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - state = reduce_get_state(obj, save_state) +def tree_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) state["__loader__"] = "tree_get_instance" return state -def tree_get_instance(state, src): - return reduce_get_instance(state, src, constructor=Tree) +def tree_get_instance(state, load_context: LoadContext): + return reduce_get_instance(state, load_context, constructor=Tree) -def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - state = reduce_get_state(obj, save_state) +def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) state["__loader__"] = "sgd_loss_get_instance" return 
state -def sgd_loss_get_instance(state, src): +def sgd_loss_get_instance(state, load_context: LoadContext): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: raise UnsupportedTypeException(f"Expected LossFunction, got {cls}") - return reduce_get_instance(state, src, constructor=cls) + return reduce_get_instance(state, load_context, constructor=cls) # TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_state( - obj: Any, save_state: SaveState + obj: Any, save_context: SaveContext ) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -143,20 +143,20 @@ def _DictWithDeprecatedKeys_get_state( "__loader__": "_DictWithDeprecatedKeys_get_instance", } content = {} - content["main"] = dict_get_state(obj, save_state) + content["main"] = dict_get_state(obj, save_context) content["_deprecated_key_to_new_key"] = dict_get_state( - obj._deprecated_key_to_new_key, save_state + obj._deprecated_key_to_new_key, save_context ) res["content"] = content return res # TODO: remove once support for sklearn<1.2 is dropped. 
-def _DictWithDeprecatedKeys_get_instance(state, src): +def _DictWithDeprecatedKeys_get_instance(state, load_context: LoadContext): # _DictWithDeprecatedKeys is just a wrapper for dict - content = dict_get_instance(state["content"]["main"], src) + content = dict_get_instance(state["content"]["main"], load_context) deprecated_key_to_new_key = dict_get_instance( - state["content"]["_deprecated_key_to_new_key"], src + state["content"]["_deprecated_key_to_new_key"], load_context ) res = _DictWithDeprecatedKeys(**content) res._deprecated_key_to_new_key = deprecated_key_to_new_key diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 08d235e4..fa46fd75 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -88,10 +88,10 @@ def get_module(obj): @dataclass(frozen=True) -class SaveState: - """State required for saving the objects +class SaveContext: + """Context required for saving the objects - This state is passed to each ``get_state_*`` function. + This context is passed to each ``get_state_*`` function. Parameters ---------- @@ -122,22 +122,49 @@ def clear_memo(self) -> None: self.memo.clear() +@dataclass(frozen=True) +class LoadContext: + """Context required for loading an object + + This context is passed to each ``get_instance_*`` function. + + Parameters + ---------- + src: zipfile.ZipFile + The zip file the target object is saved in + """ + + src: ZipFile + memo: dict[int, Any] = field(default_factory=dict) + + def memoize(self, obj: Any, id: int) -> None: + self.memo[id] = obj + + def get_instance(self, id: int) -> Any: + return self.memo.get(id) + + @singledispatch -def _get_state(obj, save_state): +def _get_state(obj, save_context): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. 
raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -def get_state(value, save_state): +def get_state(value, save_context): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. + __id__ = save_context.memoize(obj=value) + try: - return _get_state(value, save_state) + res = _get_state(value, save_context) except TypeError as e1: try: - return json.dumps(value) + res = {"is_json": True, "content": json.dumps(value)} except Exception as e2: raise e1 from e2 + + res["__id__"] = __id__ + return res diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index b999487b..0bbd7325 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,7 +56,7 @@ from skops.io import dump, dumps, load, loads from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import _get_state, get_state +from skops.io._utils import LoadContext, SaveContext, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -79,11 +79,12 @@ def debug_get_state(func): # Check consistency of argument names, output type, and that the output, # if a dict, has certain keys, or if not a dict, is a primitive type. 
signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["obj", "save_state"] + assert list(signature.parameters.keys()) == ["obj", "save_context"] @wraps(func) - def wrapper(obj, save_state): - result = func(obj, save_state) + def wrapper(obj, save_context): + # NB: __id__ set in main 'get_state' func, so no check here + result = func(obj, save_context) assert "__class__" in result assert "__module__" in result @@ -96,16 +97,17 @@ def wrapper(obj, save_state): def debug_get_instance(func): # check consistency of argument names and input type signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["state", "src"] + assert list(signature.parameters.keys()) == ["state", "load_context"] @wraps(func) - def wrapper(state, src): + def wrapper(state, load_context): assert "__class__" in state assert "__module__" in state assert "__loader__" in state - assert isinstance(src, ZipFile) + assert "__id__" in state + assert isinstance(load_context, LoadContext) - result = func(state, src) + result = func(state, load_context) return result return wrapper @@ -785,11 +787,13 @@ def test_loads_from_str(): def test_get_instance_unknown_type_error_msg(): - state = get_state(("hi", [123]), None) + val = ("hi", [123]) + save_context = SaveContext(None) + state = get_state(val, save_context) state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." 
with pytest.raises(TypeError, match=msg): - get_instance(state, None) + get_instance(state, LoadContext(None)) class _BoundMethodHolder: @@ -867,9 +871,6 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - @pytest.mark.xfail( - reason="Can't load an object as a single instance if referenced multiple times" - ) def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): obj = _BoundMethodHolder(object_state="") @@ -960,3 +961,21 @@ def test_disk_and_memory_are_identical(tmp_path): loaded_memory = loads(dumps(estimator)) assert joblib.hash(loaded_disk) == joblib.hash(loaded_memory) + + +@pytest.mark.parametrize( + "obj", + [ + np.array([1, 2]), + [1, 2, 3], + {1: 1, 2: 2}, + {1, 2, 3}, + "A string", + np.random.RandomState(42), + ], +) +def test_when_given_object_referenced_twice_loads_as_one_object(obj): + an_object = {"obj_1": obj, "obj_2": obj} + persisted_object = loads(dumps(an_object)) + + assert persisted_object["obj_1"] is persisted_object["obj_2"]