From abe490db995ed05249ad41320437c206748e4c05 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 21 Oct 2022 16:36:54 +0200 Subject: [PATCH 1/5] Refactor: get_instance method saved in state Resolves #197 Description Currently, during the dispatch of the get_instance functions, the class stored in the state is being loaded to determine which function to dispatch to. This is bad because module loading can be dangerous. We will add auditing but it is intended to be on the level of get_instance itself, not for the dispatch mechanism. In this PR, the state returned by get_state functions is augmented with the name of the get_instance method required to load the object. This way, we can look up the correct method based on the state and don't need to use the modified singledispatch mechanism, thus avoiding loading modules during dispatching. Implementation Whereas for get_state, we still rely in singledispatch, for get_instance we now use a simple dictionary that looks up the function based on its name (which in turn is stored in the state). The dictionary, going by the name of GET_INSTANCE_MAPPING, is populated similarly to how the get_instance functions were registered previously with singledispatch. There was an issue with circular imports (e.g. get_instance > GET_INSTANCE_MAPPING > ndarray_get_instance > get_instance), hence the get_instance function was moved to its own module, _dispatch.py (other name suggestions welcome). For some types, we now need extra get_state functions because they have specific get_instance methods. So e.g. sgd_loss_get_state just wraps reduce_get_state and adds sgd_loss_get_instance as its loader. Coincidental changes Since we no longer have to inspect the contents of state to determine the function to dispatch to for get_instance, we can fall back to the Python implementation of singledispatch instead of rolling our own. This side effect is a big win. The function Tree_get_instance was renamed to tree_get_instance for consistency. In the debug_dispatch_functions, there was some code from a time when the state was allowed not to be a dict (json-serializable objects). Now we always have a dict, so this dead code was removed. Also, this fixture was elevated to module-level scope, since it only needs to run once. --- skops/io/_dispatch.py | 21 +++++ skops/io/_general.py | 34 +++++--- skops/io/_numpy.py | 26 +++--- skops/io/_persist.py | 7 +- skops/io/_scipy.py | 7 +- skops/io/_sklearn.py | 44 +++++++--- skops/io/_utils.py | 149 +-------------------------------- skops/io/tests/test_persist.py | 31 +++---- skops/utils/fixes.py | 6 -- 9 files changed, 113 insertions(+), 212 deletions(-) create mode 100644 skops/io/_dispatch.py diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py new file mode 100644 index 00000000..3b1aad5a --- /dev/null +++ b/skops/io/_dispatch.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +import json +from typing import Any, Callable +from zipfile import ZipFile + +GET_INSTANCE_MAPPING: dict[str, Callable[[dict[str, Any], ZipFile], Any]] = {} + + +def get_instance(state, src): + """Create instance based on the state, using json if possible""" + if state.get("is_json"): + return json.loads(state["content"]) + + try: + get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] + except KeyError: + raise TypeError( + f"Creating an instance of type {type(state)} is not supported yet" + ) + return get_instance_func(state, src) diff --git a/skops/io/_general.py b/skops/io/_general.py index 88413472..d32d63ae 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -7,7 +7,8 @@ import numpy as np -from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype +from ._dispatch import get_instance +from ._utils import SaveState, _import_obj, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -15,6 +16,7 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "dict_get_instance", } key_types = get_state([type(key) for key in obj.keys()], save_state) @@ -43,6 +45,7 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "list_get_instance", } content = [] for value in obj: @@ -62,6 +65,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "tuple_get_instance", } content = tuple(get_state(value, save_state) for value in obj) res["content"] = content @@ -93,6 +97,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), + "__loader__": "function_get_instance", "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -111,6 +116,7 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": "partial", # don't allow any subclass "__module__": get_module(type(obj)), + "__loader__": "partial_get_instance", "content": { "func": get_state(func, save_state), "args": get_state(args, save_state), @@ -138,6 +144,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "type_get_instance", "content": { "__class__": obj.__name__, "__module__": get_module(obj), @@ -155,6 +162,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "slice_get_instance", "content": { "start": obj.start, "stop": obj.stop, @@ -181,6 +189,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return { "__class__": "str", "__module__": "builtins", + "__loader__": "none", "content": obj_str, "is_json": True, } @@ -190,6 +199,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "object_get_instance", } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -247,14 +257,14 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: (type, type_get_state), (object, object_get_state), ] -# tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (dict, dict_get_instance), - (list, list_get_instance), - (tuple, tuple_get_instance), - (slice, slice_get_instance), - (FunctionType, function_get_instance), - (partial, partial_get_instance), - (type, type_get_instance), - (object, object_get_instance), -] + +GET_INSTANCE_DISPATCH_MAPPING = { + "dict_get_instance": dict_get_instance, + "list_get_instance": list_get_instance, + "tuple_get_instance": tuple_get_instance, + "slice_get_instance": slice_get_instance, + "function_get_instance": function_get_instance, + "partial_get_instance": partial_get_instance, + "type_get_instance": type_get_instance, + "object_get_instance": object_get_instance, +} diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index cd006c2a..4d1d5b98 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -5,8 +5,9 @@ import numpy as np +from ._dispatch import get_instance from ._general import function_get_instance -from ._utils import SaveState, _import_obj, get_instance, get_module, get_state +from ._utils import SaveState, _import_obj, get_module, get_state from .exceptions import UnsupportedTypeException @@ -14,6 +15,7 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "ndarray_get_instance", } try: @@ -78,6 +80,7 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "maskedarray_get_instance", "content": { "data": get_state(obj.data, save_state), "mask": get_state(obj.mask, save_state), @@ -97,6 +100,7 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "random_state_get_instance", "content": content, } return res @@ -115,6 +119,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "random_generator_get_instance", "content": {"bit_generator": bit_generator_state}, } return res @@ -139,6 +144,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, # ufunc "__module__": get_module(type(obj)), # numpy + "__loader__": "function_get_instance", "content": { "module_path": get_module(obj), "function": obj.__name__, @@ -154,6 +160,7 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": "dtype", "__module__": "numpy", + "__loader__": "dtype_get_instance", "content": ndarray_get_state(tmp, save_state), } return res @@ -177,12 +184,11 @@ def dtype_get_instance(state, src): (np.random.Generator, random_generator_get_state), ] # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (np.generic, ndarray_get_instance), - (np.ndarray, ndarray_get_instance), - (np.ma.MaskedArray, maskedarray_get_instance), - (np.ufunc, function_get_instance), - (np.dtype, dtype_get_instance), - (np.random.RandomState, random_state_get_instance), - (np.random.Generator, random_generator_get_instance), -] +GET_INSTANCE_DISPATCH_MAPPING = { + "ndarray_get_instance": ndarray_get_instance, + "maskedarray_get_instance": maskedarray_get_instance, + "function_get_instance": function_get_instance, + "dtype_get_instance": dtype_get_instance, + "random_state_get_instance": random_state_get_instance, + "random_generator_get_instance": random_generator_get_instance, +} diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 5557353b..1906eca2 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -7,7 +7,8 @@ import skops -from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state +from ._dispatch import GET_INSTANCE_MAPPING, get_instance +from ._utils import SaveState, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -17,8 +18,8 @@ module = importlib.import_module(module_name, package="skops.io") for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []): _get_state.register(cls)(method) - for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []): - _get_instance.register(cls)(method) + # populate the the dict used for dispatching get_instance functions + GET_INSTANCE_MAPPING.update(module.GET_INSTANCE_DISPATCH_MAPPING) def _save(obj): diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 11457ec9..305d30dc 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -12,6 +12,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "sparse_matrix_get_instance", } data_buffer = io.BytesIO() @@ -49,8 +50,8 @@ def sparse_matrix_get_instance(state, src): (spmatrix, sparse_matrix_get_state), ] # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ +GET_INSTANCE_DISPATCH_MAPPING = { # use 'spmatrix' to check if a matrix is a sparse matrix because that is # what scipy.sparse.issparse checks - (spmatrix, sparse_matrix_get_instance), -] + "sparse_matrix_get_instance": sparse_matrix_get_instance, +} diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index d81c54d1..a1a6939e 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -23,8 +23,9 @@ from sklearn.tree._tree import Tree from sklearn.utils import Bunch +from ._dispatch import get_instance from ._general import dict_get_instance, dict_get_state, unsupported_get_state -from ._utils import SaveState, get_instance, get_module, get_state, gettype +from ._utils import SaveState, get_module, get_state, gettype from .exceptions import UnsupportedTypeException ALLOWED_SGD_LOSSES = { @@ -110,10 +111,22 @@ def reduce_get_instance(state, src, constructor): return instance -def Tree_get_instance(state, src): +def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + state = reduce_get_state(obj, save_state) + state["__loader__"] = "tree_get_instance" + return state + + +def tree_get_instance(state, src): return reduce_get_instance(state, src, constructor=Tree) +def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + state = reduce_get_state(obj, save_state) + state["__loader__"] = "sgd_loss_get_instance" + return state + + def sgd_loss_get_instance(state, src): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: @@ -121,6 +134,12 @@ def sgd_loss_get_instance(state, src): return reduce_get_instance(state, src, constructor=cls) +def bunch_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: + state = dict_get_state(obj, save_state) + state["__loader__"] = "bunch_get_instance" + return state + + def bunch_get_instance(state, src): # Bunch is just a wrapper for dict content = dict_get_instance(state, src) @@ -134,6 +153,7 @@ def _DictWithDeprecatedKeys_get_state( res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), + "__loader__": "_DictWithDeprecatedKeys_get_instance", } content = {} content["main"] = dict_get_state(obj, save_state) @@ -158,18 +178,18 @@ def _DictWithDeprecatedKeys_get_instance(state, src): # tuples of type and function that gets the state of that type GET_STATE_DISPATCH_FUNCTIONS = [ - (LossFunction, reduce_get_state), - (Tree, reduce_get_state), + (LossFunction, sgd_loss_get_state), + (Tree, tree_get_state), ] for type_ in UNSUPPORTED_TYPES: GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state)) # tuples of type and function that creates the instance of that type -GET_INSTANCE_DISPATCH_FUNCTIONS = [ - (LossFunction, sgd_loss_get_instance), - (Tree, Tree_get_instance), - (Bunch, bunch_get_instance), -] +GET_INSTANCE_DISPATCH_MAPPING = { + "sgd_loss_get_instance": sgd_loss_get_instance, + "tree_get_instance": tree_get_instance, + "bunch_get_instance": bunch_get_instance, +} # TODO: remove once support for sklearn<1.2 is dropped. # Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no @@ -178,6 +198,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src): GET_STATE_DISPATCH_FUNCTIONS.append( (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state) ) - GET_INSTANCE_DISPATCH_FUNCTIONS.append( - (_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance) - ) + GET_INSTANCE_DISPATCH_MAPPING[ + "_DictWithDeprecatedKeys_get_instance" + ] = _DictWithDeprecatedKeys_get_instance diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 7396f683..c57f65d2 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -4,138 +4,11 @@ import json # type: ignore import sys from dataclasses import dataclass, field -from functools import _find_impl, get_cache_token, update_wrapper # type: ignore +from functools import singledispatch from types import FunctionType from typing import Any from zipfile import ZipFile -from skops.utils.fixes import GenericAlias - - -# This is an almost 1:1 copy of functools.singledispatch. There is one crucial -# difference, however. Usually, we want to dispatch on the class of the object. -# However, when we call get_instance, the object is *always* a dict, which -# invalidates the dispatch. Therefore, we change the dispatcher to dispatch on -# the instance, not the class. By default, we just use the class of the instance -# being passed, i.e. we do exactly the same as in the original implementation. -# However, if we encounter a state dict, we resolve the actual class from the -# state dict first and then dispatch on that class. The changed lines are marked -# as "# CHANGED". -# fmt: off -def singledispatch(func): - """Single-dispatch generic function decorator. - - Transforms a function into a generic function, which can have different - behaviours depending upon the type of its first argument. The decorated - function acts as the default implementation, and additional - implementations can be registered using the register() attribute of the - generic function. - """ - # There are many programs that use functools without singledispatch, so we - # trade-off making singledispatch marginally slower for the benefit of - # making start-up of such applications slightly faster. - import types - import weakref - - registry = {} - dispatch_cache = weakref.WeakKeyDictionary() - cache_token = None - - def dispatch(instance): # CHANGED: variable name cls->instance - """generic_func.dispatch(cls) -> - - Runs the dispatch algorithm to return the best available implementation - for the given *cls* registered on *generic_func*. - - """ - # CHANGED: check if we deal with a state dict, in which case we use it - # to resolve the correct class. Otherwise, just use the class of the - # instance. - if ( - isinstance(instance, dict) - and "__module__" in instance - and "__class__" in instance - ): - cls = gettype(instance) - else: - cls = instance.__class__ - - nonlocal cache_token - if cache_token is not None: - current_token = get_cache_token() - if cache_token != current_token: - dispatch_cache.clear() - cache_token = current_token - try: - impl = dispatch_cache[cls] - except KeyError: - try: - impl = registry[cls] - except KeyError: - impl = _find_impl(cls, registry) - dispatch_cache[cls] = impl - return impl - - def _is_valid_dispatch_type(cls): - return isinstance(cls, type) and not isinstance(cls, GenericAlias) - - def register(cls, func=None): - """generic_func.register(cls, func) -> func - - Registers a new implementation for the given *cls* on a *generic_func*. - - """ - nonlocal cache_token - if _is_valid_dispatch_type(cls): - if func is None: - return lambda f: register(cls, f) - else: - if func is not None: - raise TypeError( - f"Invalid first argument to `register()`. " - f"{cls!r} is not a class." - ) - ann = getattr(cls, '__annotations__', {}) - if not ann: - raise TypeError( - f"Invalid first argument to `register()`: {cls!r}. " - f"Use either `@register(some_class)` or plain `@register` " - f"on an annotated function." - ) - func = cls - # only import typing if annotation parsing is necessary - from typing import get_type_hints - argname, cls = next(iter(get_type_hints(func).items())) - if not _is_valid_dispatch_type(cls): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - - registry[cls] = func - if cache_token is None and hasattr(cls, '__abstractmethods__'): - cache_token = get_cache_token() - dispatch_cache.clear() - return func - - def wrapper(*args, **kw): - if not args: - raise TypeError(f'{funcname} requires at least ' - '1 positional argument') - - # CHANGED: dispatch on instance, not class - return dispatch(args[0])(*args, **kw) - - funcname = getattr(func, '__name__', 'singledispatch function') - registry[object] = func - wrapper.register = register - wrapper.dispatch = dispatch - wrapper.registry = types.MappingProxyType(registry) - wrapper._clear_cache = dispatch_cache.clear - update_wrapper(wrapper, func) - return wrapper -# fmt: on - # The following two functions are copied from cpython's pickle.py file. # --------------------------------------------------------------------- @@ -265,14 +138,6 @@ def _get_state(obj, dst): raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -@singledispatch -def _get_instance(obj, src): - # This function should never be called directly. Instead, it is used to - # dispatch to the correct implementation of get_instance for the given type - # of its first argument. - raise TypeError(f"Creating an instance of type {type(obj)} is not supported yet") - - def get_state(value, dst): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise @@ -284,15 +149,3 @@ def get_state(value, dst): return json.dumps(value) except Exception as e2: raise e1 from e2 - - -def get_instance(value, src): - # This is a helper function to try to get the state of an object. If - # `gettype` fails, we load with `json`. - if value is None: - return None - - if gettype(value): - return _get_instance(value, src) - - return json.loads(value) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index f470880a..7760b956 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -51,8 +51,9 @@ import skops from skops.io import dumps, load, loads, save +from skops.io._dispatch import GET_INSTANCE_MAPPING from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import _get_instance, _get_state +from skops.io._utils import _get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -64,7 +65,7 @@ ATOL = 1e-6 if sys.platform == "darwin" else 1e-7 -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="module") def debug_dispatch_functions(): # Patch the get_state and get_instance methods to add some sanity checks on # them. Specifically, we test that the arguments of the functions all follow @@ -81,12 +82,10 @@ def debug_get_state(func): def wrapper(obj, save_state): result = func(obj, save_state) - if isinstance(result, dict): - assert "__class__" in result - assert "__module__" in result - else: - # should be a primitive type - assert isinstance(result, (int, float, str)) + assert "__class__" in result + assert "__module__" in result + assert "__loader__" in result + return result return wrapper @@ -98,16 +97,12 @@ def debug_get_instance(func): @wraps(func) def wrapper(state, src): - if isinstance(state, dict): - assert "__class__" in state - assert "__module__" in state - else: - # should be a primitive type - assert isinstance(state, (int, float, str)) - assert (src is None) or isinstance(src, ZipFile) + assert "__class__" in state + assert "__module__" in state + assert "__loader__" in state + assert isinstance(src, ZipFile) result = func(state, src) - return result return wrapper @@ -118,8 +113,8 @@ def wrapper(state, src): module = importlib.import_module(module_name, package="skops.io") for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []): _get_state.register(cls)(debug_get_state(method)) - for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []): - _get_instance.register(cls)(debug_get_instance(method)) + for key, method in GET_INSTANCE_MAPPING.copy().items(): + GET_INSTANCE_MAPPING[key] = debug_get_instance(method) def save_load_round(estimator, f_name, dump_method="fs"): diff --git a/skops/utils/fixes.py b/skops/utils/fixes.py index 4cf39265..e9d83558 100644 --- a/skops/utils/fixes.py +++ b/skops/utils/fixes.py @@ -4,7 +4,6 @@ import sys from contextlib import suppress from pathlib import Path -from typing import List if sys.version_info >= (3, 8): # py>=3.8 @@ -22,11 +21,6 @@ # _min_dependencies.py from typing_extensions import Literal # noqa -if sys.version_info >= (3, 9): - from types import GenericAlias -else: - GenericAlias = type(List[int]) - def path_unlink(path: Path, missing_ok: bool = False) -> None: """Remove this file or symbolic link From 0109007673479e279fc1fe7a144ef531804a1346 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 21 Oct 2022 17:00:05 +0200 Subject: [PATCH 2/5] Add annotations import for older Python versions --- skops/io/_dispatch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index 3b1aad5a..f5340420 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +from __future__ import annotations + import json from typing import Any, Callable from zipfile import ZipFile From 8314e18b616330be45101f94ca716214623c1476 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 21 Oct 2022 17:35:26 +0200 Subject: [PATCH 3/5] Remove special treatment of Bunch Tests pass without it. --- skops/io/_sklearn.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index a1a6939e..6cde5463 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -21,7 +21,6 @@ SquaredLoss, ) from sklearn.tree._tree import Tree -from sklearn.utils import Bunch from ._dispatch import get_instance from ._general import dict_get_instance, dict_get_state, unsupported_get_state @@ -134,18 +133,6 @@ def sgd_loss_get_instance(state, src): return reduce_get_instance(state, src, constructor=cls) -def bunch_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: - state = dict_get_state(obj, save_state) - state["__loader__"] = "bunch_get_instance" - return state - - -def bunch_get_instance(state, src): - # Bunch is just a wrapper for dict - content = dict_get_instance(state, src) - return Bunch(**content) - - # TODO: remove once support for sklearn<1.2 is dropped. def _DictWithDeprecatedKeys_get_state( obj: Any, save_state: SaveState @@ -188,7 +175,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src): GET_INSTANCE_DISPATCH_MAPPING = { "sgd_loss_get_instance": sgd_loss_get_instance, "tree_get_instance": tree_get_instance, - "bunch_get_instance": bunch_get_instance, } # TODO: remove once support for sklearn<1.2 is dropped. From 6ca266034836ef09afdb49b5f1fababe0fd66046 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 24 Oct 2022 12:42:22 +0200 Subject: [PATCH 4/5] Address reviewer comments - Remove type annotation for GET_INSTANCE_MAPPING - Better error message if loader not found Also changed: - Misleading var names in get_state --- skops/io/_dispatch.py | 7 +++---- skops/io/_utils.py | 6 +++--- skops/io/tests/test_persist.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index f5340420..a64e5f15 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -3,10 +3,8 @@ from __future__ import annotations import json -from typing import Any, Callable -from zipfile import ZipFile -GET_INSTANCE_MAPPING: dict[str, Callable[[dict[str, Any], ZipFile], Any]] = {} +GET_INSTANCE_MAPPING = {} # type: ignore def get_instance(state, src): @@ -17,7 +15,8 @@ def get_instance(state, src): try: get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" raise TypeError( - f"Creating an instance of type {type(state)} is not supported yet" + f" Can't find loader {state['__loader__']} for type {type_name}." ) return get_instance_func(state, src) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index c57f65d2..82dc1a92 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -131,19 +131,19 @@ def clear_memo(self) -> None: @singledispatch -def _get_state(obj, dst): +def _get_state(obj, save_state): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -def get_state(value, dst): +def get_state(value, save_state): # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. try: - return _get_state(value, dst) + return _get_state(value, save_state) except TypeError as e1: try: return json.dumps(value) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 3d23137e..54ce4faf 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -51,9 +51,9 @@ import skops from skops.io import dump, dumps, load, loads -from skops.io._dispatch import GET_INSTANCE_MAPPING +from skops.io._dispatch import GET_INSTANCE_MAPPING, get_instance from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._utils import _get_state +from skops.io._utils import _get_state, get_state from skops.io.exceptions import UnsupportedTypeException # Default settings for X @@ -799,3 +799,11 @@ def test_loads_from_str(): msg = "Can't load skops format from string, pass bytes" with pytest.raises(TypeError, match=msg): loads("this is a string") + + +def test_get_instance_unknown_type_error_msg(): + state = get_state(("hi", [123]), None) + state["__loader__"] = "this_get_instance_does_not_exist" + msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." + with pytest.raises(TypeError, match=msg): + get_instance(state, None) From 82d0332c3891e7cba0034ec389d273fc647dd93b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 24 Oct 2022 12:44:07 +0200 Subject: [PATCH 5/5] Remove unnecessary shebang --- skops/io/_dispatch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index a64e5f15..e0ae9a96 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from __future__ import annotations import json