diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py new file mode 100644 index 00000000..e0ae9a96 --- /dev/null +++ b/skops/io/_dispatch.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import json + +GET_INSTANCE_MAPPING = {} # type: ignore + + +def get_instance(state, src): + """Create instance based on the state, using json if possible""" + if state.get("is_json"): + return json.loads(state["content"]) + + try: + get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] + except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name}." + ) + return get_instance_func(state, src) diff --git a/skops/io/_general.py b/skops/io/_general.py index 88413472..d32d63ae 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -7,7 +7,8 @@ import numpy as np -from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype +from ._dispatch import get_instance +from ._utils import SaveState, _import_obj, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -15,6 +16,7 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "dict_get_instance", } key_types = get_state([type(key) for key in obj.keys()], save_state) @@ -43,6 +45,7 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "list_get_instance", } content = [] for value in obj: @@ -62,6 +65,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "tuple_get_instance", } content = tuple(get_state(value, save_state) for value in obj) res["content"] = content @@ -93,6 +97,7 @@ def 
function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), + "__loader__": "function_get_instance", "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -111,6 +116,7 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": "partial", # don't allow any subclass "__module__": get_module(type(obj)), + "__loader__": "partial_get_instance", "content": { "func": get_state(func, save_state), "args": get_state(args, save_state), @@ -138,6 +144,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "type_get_instance", "content": { "__class__": obj.__name__, "__module__": get_module(obj), @@ -155,6 +162,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "slice_get_instance", "content": { "start": obj.start, "stop": obj.stop, @@ -181,6 +189,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return { "__class__": "str", "__module__": "builtins", + "__loader__": "none", "content": obj_str, "is_json": True, } @@ -190,6 +199,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "object_get_instance", } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -247,14 +257,14 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: (type, type_get_state), (object, object_get_state), ] -# tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (dict, dict_get_instance), - (list, list_get_instance), - (tuple, tuple_get_instance), - (slice, 
slice_get_instance), - (FunctionType, function_get_instance), - (partial, partial_get_instance), - (type, type_get_instance), - (object, object_get_instance), -] + +GET_INSTANCE_DISPATCH_MAPPING = { + "dict_get_instance": dict_get_instance, + "list_get_instance": list_get_instance, + "tuple_get_instance": tuple_get_instance, + "slice_get_instance": slice_get_instance, + "function_get_instance": function_get_instance, + "partial_get_instance": partial_get_instance, + "type_get_instance": type_get_instance, + "object_get_instance": object_get_instance, +} diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index cd006c2a..4d1d5b98 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -5,8 +5,9 @@ import numpy as np +from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import SaveState, _import_obj, get_instance, get_module, get_state +from ._utils import SaveState, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException @@ -14,6 +15,7 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "ndarray_get_instance", } try: @@ -78,6 +80,7 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "maskedarray_get_instance", "content": { "data": get_state(obj.data, save_state), "mask": get_state(obj.mask, save_state), @@ -97,6 +100,7 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "random_state_get_instance", "content": content, } return res @@ -115,6 +119,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + 
"__loader__": "random_generator_get_instance", "content": {"bit_generator": bit_generator_state}, } return res @@ -139,6 +144,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc "__module__": get_module(type(obj)), # numpy + "__loader__": "function_get_instance", "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -154,6 +160,7 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": "dtype", "__module__": "numpy", + "__loader__": "dtype_get_instance", "content": ndarray_get_state(tmp, save_state), } return res @@ -177,12 +184,11 @@ def dtype_get_instance(state, src): (np.random.Generator, random_generator_get_state), ] # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (np.generic, ndarray_get_instance), - (np.ndarray, ndarray_get_instance), - (np.ma.MaskedArray, maskedarray_get_instance), - (np.ufunc, function_get_instance), - (np.dtype, dtype_get_instance), - (np.random.RandomState, random_state_get_instance), - (np.random.Generator, random_generator_get_instance), -] +GET_INSTANCE_DISPATCH_MAPPING = { + "ndarray_get_instance": ndarray_get_instance, + "maskedarray_get_instance": maskedarray_get_instance, + "function_get_instance": function_get_instance, + "dtype_get_instance": dtype_get_instance, + "random_state_get_instance": random_state_get_instance, + "random_generator_get_instance": random_generator_get_instance, +} diff --git a/skops/io/_persist.py b/skops/io/_persist.py index a20d83a5..9144da08 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -7,7 +7,8 @@ import skops -from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state +from ._dispatch import GET_INSTANCE_MAPPING, get_instance +from ._utils import SaveState, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. 
@@ -17,8 +18,8 @@ module = importlib.import_module(module_name, package="skops.io") for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []): _get_state.register(cls)(method) - for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []): - _get_instance.register(cls)(method) + # populate the dict used for dispatching get_instance functions + GET_INSTANCE_MAPPING.update(module.GET_INSTANCE_DISPATCH_MAPPING) def _save(obj): diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 11457ec9..305d30dc 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -12,6 +12,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "sparse_matrix_get_instance", } data_buffer = io.BytesIO() @@ -49,8 +50,8 @@ def sparse_matrix_get_instance(state, src): (spmatrix, sparse_matrix_get_state), ] # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ +GET_INSTANCE_DISPATCH_MAPPING = { # use 'spmatrix' to check if a matrix is a sparse matrix because that is # what scipy.sparse.issparse checks - (spmatrix, sparse_matrix_get_instance), -] + "sparse_matrix_get_instance": sparse_matrix_get_instance, +} diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index d81c54d1..6cde5463 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -21,10 +21,10 @@ SquaredLoss, ) from sklearn.tree._tree import Tree -from sklearn.utils import Bunch +from ._dispatch import get_instance from ._general import dict_get_instance, dict_get_state, unsupported_get_state -from ._utils import SaveState, get_instance, get_module, get_state, gettype +from ._utils import SaveState, get_module, get_state, gettype from .exceptions import UnsupportedTypeException ALLOWED_SGD_LOSSES = { @@ -110,10 +110,22 @@ def reduce_get_instance(state, src, constructor): return instance -def Tree_get_instance(state, 
src): +def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + state = reduce_get_state(obj, save_state) + state["__loader__"] = "tree_get_instance" + return state + + +def tree_get_instance(state, src): return reduce_get_instance(state, src, constructor=Tree) +def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + state = reduce_get_state(obj, save_state) + state["__loader__"] = "sgd_loss_get_instance" + return state + + def sgd_loss_get_instance(state, src): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: @@ -121,12 +133,6 @@ def sgd_loss_get_instance(state, src): return reduce_get_instance(state, src, constructor=cls) -def bunch_get_instance(state, src): - # Bunch is just a wrapper for dict - content = dict_get_instance(state, src) - return Bunch(**content) - - # TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_state( obj: Any, save_state: SaveState @@ -134,6 +140,7 @@ def _DictWithDeprecatedKeys_get_state( res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "_DictWithDeprecatedKeys_get_instance", } content = {} content["main"] = dict_get_state(obj, save_state) @@ -158,18 +165,17 @@ def _DictWithDeprecatedKeys_get_instance(state, src): # tuples of type and function that gets the state of that type GET_STATE_DISPATCH_FUNCTIONS = [ - (LossFunction, reduce_get_state), - (Tree, reduce_get_state), + (LossFunction, sgd_loss_get_state), + (Tree, tree_get_state), ] for type_ in UNSUPPORTED_TYPES: GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state)) # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (LossFunction, sgd_loss_get_instance), - (Tree, Tree_get_instance), - (Bunch, bunch_get_instance), -] +GET_INSTANCE_DISPATCH_MAPPING = { + "sgd_loss_get_instance": sgd_loss_get_instance, + "tree_get_instance": tree_get_instance, +} # TODO: remove once support for sklearn<1.2 is 
dropped. # Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no @@ -178,6 +184,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src): GET_STATE_DISPATCH_FUNCTIONS.append( (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) ) - GET_INSTANCE_DISPATCH_FUNCTIONS.append( - (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance) - ) + GET_INSTANCE_DISPATCH_MAPPING[ + "_DictWithDeprecatedKeys_get_instance" + ] = _DictWithDeprecatedKeys_get_instance diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 7396f683..82dc1a92 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -4,138 +4,11 @@ import json # type: ignore import sys from dataclasses import dataclass, field -from functools import _find_impl, get_cache_token, update_wrapper # type: ignore +from functools import singledispatch from types import FunctionType from typing import Any from zipfile import ZipFile -from skops.utils.fixes import GenericAlias - - -# This is an almost 1:1 copy of functools.singledispatch. There is one crucial -# difference, however. Usually, we want to dispatch on the class of the object. -# However, when we call get_instance, the object is *always* a dict, which -# invalidates the dispatch. Therefore, we change the dispatcher to dispatch on -# the instance, not the class. By default, we just use the class of the instance -# being passed, i.e. we do exactly the same as in the original implementation. -# However, if we encounter a state dict, we resolve the actual class from the -# state dict first and then dispatch on that class. The changed lines are marked -# as "# CHANGED". -# fmt: off -def singledispatch(func): - """Single-dispatch generic function decorator. - - Transforms a function into a generic function, which can have different - behaviours depending upon the type of its first argument. 
The decorated - function acts as the default implementation, and additional - implementations can be registered using the register() attribute of the - generic function. - """ - # There are many programs that use functools without singledispatch, so we - # trade-off making singledispatch marginally slower for the benefit of - # making start-up of such applications slightly faster. - import types - import weakref - - registry = {} - dispatch_cache = weakref.WeakKeyDictionary() - cache_token = None - - def dispatch(instance): # CHANGED: variable name cls->instance - """generic_func.dispatch(cls) -> - - Runs the dispatch algorithm to return the best available implementation - for the given *cls* registered on *generic_func*. - - """ - # CHANGED: check if we deal with a state dict, in which case we use it - # to resolve the correct class. Otherwise, just use the class of the - # instance. - if ( - isinstance(instance, dict) - and "__module__" in instance - and "__class__" in instance - ): - cls = gettype(instance) - else: - cls = instance.__class__ - - nonlocal cache_token - if cache_token is not None: - current_token = get_cache_token() - if cache_token != current_token: - dispatch_cache.clear() - cache_token = current_token - try: - impl = dispatch_cache[cls] - except KeyError: - try: - impl = registry[cls] - except KeyError: - impl = _find_impl(cls, registry) - dispatch_cache[cls] = impl - return impl - - def _is_valid_dispatch_type(cls): - return isinstance(cls, type) and not isinstance(cls, GenericAlias) - - def register(cls, func=None): - """generic_func.register(cls, func) -> func - - Registers a new implementation for the given *cls* on a *generic_func*. - - """ - nonlocal cache_token - if _is_valid_dispatch_type(cls): - if func is None: - return lambda f: register(cls, f) - else: - if func is not None: - raise TypeError( - f"Invalid first argument to `register()`. " - f"{cls!r} is not a class." 
- ) - ann = getattr(cls, '__annotations__', {}) - if not ann: - raise TypeError( - f"Invalid first argument to `register()`: {cls!r}. " - f"Use either `@register(some_class)` or plain `@register` " - f"on an annotated function." - ) - func = cls - # only import typing if annotation parsing is necessary - from typing import get_type_hints - argname, cls = next(iter(get_type_hints(func).items())) - if not _is_valid_dispatch_type(cls): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - - registry[cls] = func - if cache_token is None and hasattr(cls, '__abstractmethods__'): - cache_token = get_cache_token() - dispatch_cache.clear() - return func - - def wrapper(*args, **kw): - if not args: - raise TypeError(f'{funcname} requires at least ' - '1 positional argument') - - # CHANGED: dispatch on instance, not class - return dispatch(args[0])(*args, **kw) - - funcname = getattr(func, '__name__', 'singledispatch function') - registry[object] = func - wrapper.register = register - wrapper.dispatch = dispatch - wrapper.registry = types.MappingProxyType(registry) - wrapper._clear_cache = dispatch_cache.clear - update_wrapper(wrapper, func) - return wrapper -# fmt: on - # The following two functions are copied from cpython's pickle.py file. # --------------------------------------------------------------------- @@ -258,41 +131,21 @@ def clear_memo(self) -> None: @singledispatch -def _get_state(obj, dst): +def _get_state(obj, save_state): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -@singledispatch -def _get_instance(obj, src): - # This function should never be called directly. Instead, it is used to - # dispatch to the correct implementation of get_instance for the given type - # of its first argument. 
- raise TypeError(f"Creating an instance of type {type(obj)} is not supported yet") - - -def get_state(value, dst): +def get_state(value, save_state): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. try: - return _get_state(value, dst) + return _get_state(value, save_state) except TypeError as e1: try: return json.dumps(value) except Exception as e2: raise e1 from e2 - - -def get_instance(value, src): - # This is a helper function to try to get the state of an object. If - # `gettype` fails, we load with `json`. - if value is None: - return None - - if gettype(value): - return _get_instance(value, src) - - return json.loads(value) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index acd7c564..54ce4faf 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -51,8 +51,9 @@ import skops from skops.io import dump, dumps, load, loads +from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import _get_instance, _get_state +from skops.io._utils import _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -64,7 +65,7 @@ ATOL = 1e-6 if sys.platform == "darwin" else 1e-7 -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="module") def debug_dispatch_functions(): # Patch the get_state and get_instance methods to add some sanity checks on # them. 
Specifically, we test that the arguments of the functions all follow @@ -81,12 +82,10 @@ def debug_get_state(func): def wrapper(obj, save_state): result = func(obj, save_state) - if isinstance(result, dict): - assert "__class__" in result - assert "__module__" in result - else: - # should be a primitive type - assert isinstance(result, (int, float, str)) + assert "__class__" in result + assert "__module__" in result + assert "__loader__" in result + return result return wrapper @@ -98,16 +97,12 @@ def debug_get_instance(func): @wraps(func) def wrapper(state, src): - if isinstance(state, dict): - assert "__class__" in state - assert "__module__" in state - else: - # should be a primitive type - assert isinstance(state, (int, float, str)) - assert (src is None) or isinstance(src, ZipFile) + assert "__class__" in state + assert "__module__" in state + assert "__loader__" in state + assert isinstance(src, ZipFile) result = func(state, src) - return result return wrapper @@ -118,8 +113,8 @@ def wrapper(state, src): module = importlib.import_module(module_name, package="skops.io") for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []): _get_state.register(cls)(debug_get_state(method)) - for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []): - _get_instance.register(cls)(debug_get_instance(method)) + for key, method in GET_INSTANCE_MAPPING.copy().items(): + GET_INSTANCE_MAPPING[key] = debug_get_instance(method) def save_load_round(estimator, f_name, dump_method="fs"): @@ -804,3 +799,11 @@ def test_loads_from_str(): msg = "Can't load skops format from string, pass bytes" with pytest.raises(TypeError, match=msg): loads("this is a string") + + +def test_get_instance_unknown_type_error_msg(): + state = get_state(("hi", [123]), None) + state["__loader__"] = "this_get_instance_does_not_exist" + msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." 
+ with pytest.raises(TypeError, match=msg): + get_instance(state, None) diff --git a/skops/utils/fixes.py b/skops/utils/fixes.py index 4cf39265..e9d83558 100644 --- a/skops/utils/fixes.py +++ b/skops/utils/fixes.py @@ -4,7 +4,6 @@ import sys from contextlib import suppress from pathlib import Path -from typing import List if sys.version_info >= (3, 8): # py>=3.8 @@ -22,11 +21,6 @@ # _min_dependencies.py from typing_extensions import Literal # noqa -if sys.version_info >= (3, 9): - from types import GenericAlias -else: - GenericAlias = type(List[int]) - def path_unlink(path: Path, missing_ok: bool = False) -> None: """Remove this file or symbolic link