Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions skops/io/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

import json

# Maps the loader name stored under "__loader__" in a saved state to the
# function that creates the instance from that state.
GET_INSTANCE_MAPPING = {}  # type: ignore


def get_instance(state, src):
    """Create instance based on the state, using json if possible.

    Parameters
    ----------
    state : dict
        The saved state of the object. If ``state["is_json"]`` is truthy,
        ``state["content"]`` is decoded with ``json.loads``. Otherwise
        ``state["__loader__"]`` names the entry of ``GET_INSTANCE_MAPPING``
        that is called to rebuild the object.

    src
        Passed through unchanged to the selected loader function.

    Returns
    -------
    instance
        The decoded json content, or whatever the selected loader returns.

    Raises
    ------
    TypeError
        If the state names no loader, or names a loader that is not
        registered in ``GET_INSTANCE_MAPPING``.
    """
    if state.get("is_json"):
        return json.loads(state["content"])

    # Read the keys with ``.get`` so that a state lacking "__loader__" (or
    # "__module__"/"__class__") ends up in the TypeError below instead of
    # leaking a bare KeyError out of the error handler itself.
    loader = state.get("__loader__")
    try:
        get_instance_func = GET_INSTANCE_MAPPING[loader]
    except KeyError:
        type_name = f"{state.get('__module__')}.{state.get('__class__')}"
        raise TypeError(
            f" Can't find loader {loader} for type {type_name}."
        ) from None
    return get_instance_func(state, src)
34 changes: 22 additions & 12 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

import numpy as np

from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype
from ._dispatch import get_instance
from ._utils import SaveState, _import_obj, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException


def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "dict_get_instance",
}

key_types = get_state([type(key) for key in obj.keys()], save_state)
Expand Down Expand Up @@ -43,6 +45,7 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "list_get_instance",
}
content = []
for value in obj:
Expand All @@ -62,6 +65,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "tuple_get_instance",
}
content = tuple(get_state(value, save_state) for value in obj)
res["content"] = content
Expand Down Expand Up @@ -93,6 +97,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(obj),
"__loader__": "function_get_instance",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
Expand All @@ -111,6 +116,7 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": "partial", # don't allow any subclass
"__module__": get_module(type(obj)),
"__loader__": "partial_get_instance",
"content": {
"func": get_state(func, save_state),
"args": get_state(args, save_state),
Expand Down Expand Up @@ -138,6 +144,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "type_get_instance",
"content": {
"__class__": obj.__name__,
"__module__": get_module(obj),
Expand All @@ -155,6 +162,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "slice_get_instance",
"content": {
"start": obj.start,
"stop": obj.stop,
Expand All @@ -181,6 +189,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
return {
"__class__": "str",
"__module__": "builtins",
"__loader__": "none",
"content": obj_str,
"is_json": True,
}
Expand All @@ -190,6 +199,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "object_get_instance",
}

# __getstate__ takes priority over __dict__, and if non exist, we only save
Expand Down Expand Up @@ -247,14 +257,14 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
(type, type_get_state),
(object, object_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(dict, dict_get_instance),
(list, list_get_instance),
(tuple, tuple_get_instance),
(slice, slice_get_instance),
(FunctionType, function_get_instance),
(partial, partial_get_instance),
(type, type_get_instance),
(object, object_get_instance),
]

# Maps the loader name that the get_state functions of this module store
# under "__loader__" to the function that creates the instance from the state.
GET_INSTANCE_DISPATCH_MAPPING = {
    "dict_get_instance": dict_get_instance,
    "list_get_instance": list_get_instance,
    "tuple_get_instance": tuple_get_instance,
    "slice_get_instance": slice_get_instance,
    "function_get_instance": function_get_instance,
    "partial_get_instance": partial_get_instance,
    "type_get_instance": type_get_instance,
    "object_get_instance": object_get_instance,
}
26 changes: 16 additions & 10 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

import numpy as np

from ._dispatch import get_instance
from ._general import function_get_instance
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state
from ._utils import SaveState, _import_obj, get_module, get_state
from .exceptions import UnsupportedTypeException


def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "ndarray_get_instance",
}

try:
Expand Down Expand Up @@ -78,6 +80,7 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "maskedarray_get_instance",
"content": {
"data": get_state(obj.data, save_state),
"mask": get_state(obj.mask, save_state),
Expand All @@ -97,6 +100,7 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "random_state_get_instance",
"content": content,
}
return res
Expand All @@ -115,6 +119,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "random_generator_get_instance",
"content": {"bit_generator": bit_generator_state},
}
return res
Expand All @@ -139,6 +144,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
"__loader__": "function_get_instance",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
Expand All @@ -154,6 +160,7 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": "dtype",
"__module__": "numpy",
"__loader__": "dtype_get_instance",
"content": ndarray_get_state(tmp, save_state),
}
return res
Expand All @@ -177,12 +184,11 @@ def dtype_get_instance(state, src):
(np.random.Generator, random_generator_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(np.generic, ndarray_get_instance),
(np.ndarray, ndarray_get_instance),
(np.ma.MaskedArray, maskedarray_get_instance),
(np.ufunc, function_get_instance),
(np.dtype, dtype_get_instance),
(np.random.RandomState, random_state_get_instance),
(np.random.Generator, random_generator_get_instance),
]
# Maps the loader name that the get_state functions of this module store
# under "__loader__" to the function that creates the instance from the state.
GET_INSTANCE_DISPATCH_MAPPING = {
    "ndarray_get_instance": ndarray_get_instance,
    "maskedarray_get_instance": maskedarray_get_instance,
    # ufunc_get_state stores "function_get_instance" as its loader
    "function_get_instance": function_get_instance,
    "dtype_get_instance": dtype_get_instance,
    "random_state_get_instance": random_state_get_instance,
    "random_generator_get_instance": random_generator_get_instance,
}
7 changes: 4 additions & 3 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import skops

from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state
from ._dispatch import GET_INSTANCE_MAPPING, get_instance
from ._utils import SaveState, _get_state, get_state

# We load the dispatch functions from the corresponding modules and register
# them.
Expand All @@ -17,8 +18,8 @@
module = importlib.import_module(module_name, package="skops.io")
for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []):
_get_state.register(cls)(method)
for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []):
_get_instance.register(cls)(method)
# populate the dict used for dispatching get_instance functions
GET_INSTANCE_MAPPING.update(module.GET_INSTANCE_DISPATCH_MAPPING)


def _save(obj):
Expand Down
7 changes: 4 additions & 3 deletions skops/io/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "sparse_matrix_get_instance",
}

data_buffer = io.BytesIO()
Expand Down Expand Up @@ -49,8 +50,8 @@ def sparse_matrix_get_instance(state, src):
(spmatrix, sparse_matrix_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
GET_INSTANCE_DISPATCH_MAPPING = {
# use 'spmatrix' to check if a matrix is a sparse matrix because that is
# what scipy.sparse.issparse checks
(spmatrix, sparse_matrix_get_instance),
]
"sparse_matrix_get_instance": sparse_matrix_get_instance,
}
44 changes: 25 additions & 19 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
SquaredLoss,
)
from sklearn.tree._tree import Tree
from sklearn.utils import Bunch
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because dict_get_instance now loads the class of the saved object and creates an instance of that class, whatever it was, instead of always creating a plain dict.


from ._dispatch import get_instance
from ._general import dict_get_instance, dict_get_state, unsupported_get_state
from ._utils import SaveState, get_instance, get_module, get_state, gettype
from ._utils import SaveState, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

ALLOWED_SGD_LOSSES = {
Expand Down Expand Up @@ -110,30 +110,37 @@ def reduce_get_instance(state, src, constructor):
return instance


def Tree_get_instance(state, src):
def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Return the ``reduce_get_state`` state of ``obj`` with its loader set.

    The "__loader__" entry is set to "tree_get_instance" so that loading is
    dispatched to ``tree_get_instance``.
    """
    tree_state = reduce_get_state(obj, save_state)
    # record the name of the function that rebuilds the object on load
    tree_state["__loader__"] = "tree_get_instance"
    return tree_state


def tree_get_instance(state, src):
    """Rebuild an object from ``state`` using ``Tree`` as the constructor."""
    return reduce_get_instance(state=state, src=src, constructor=Tree)


def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Return the ``reduce_get_state`` state of ``obj`` with its loader set.

    The "__loader__" entry is set to "sgd_loss_get_instance" so that loading
    is dispatched to ``sgd_loss_get_instance``.
    """
    loss_state = reduce_get_state(obj, save_state)
    # record the name of the function that rebuilds the object on load
    loss_state["__loader__"] = "sgd_loss_get_instance"
    return loss_state


def sgd_loss_get_instance(state, src):
    """Rebuild an SGD loss object from ``state``.

    The class is resolved with ``gettype(state)`` and must be one of
    ``ALLOWED_SGD_LOSSES``; it is then used as the constructor for
    ``reduce_get_instance``.

    Raises
    ------
    UnsupportedTypeException
        If the resolved class is not in ``ALLOWED_SGD_LOSSES``.
    """
    loss_cls = gettype(state)
    if loss_cls in ALLOWED_SGD_LOSSES:
        return reduce_get_instance(state, src, constructor=loss_cls)
    # anything outside the allow-list is rejected rather than constructed
    raise UnsupportedTypeException(f"Expected LossFunction, got {loss_cls}")


def bunch_get_instance(state, src):
    """Rebuild a ``Bunch`` from the result of ``dict_get_instance``."""
    # Bunch is just a wrapper for dict, so load the dict and re-wrap it
    return Bunch(**dict_get_instance(state, src))


# TODO: remove once support for sklearn<1.2 is dropped.
def _DictWithDeprecatedKeys_get_state(
obj: Any, save_state: SaveState
) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "_DictWithDeprecatedKeys_get_instance",
}
content = {}
content["main"] = dict_get_state(obj, save_state)
Expand All @@ -158,18 +165,17 @@ def _DictWithDeprecatedKeys_get_instance(state, src):

# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(LossFunction, reduce_get_state),
(Tree, reduce_get_state),
(LossFunction, sgd_loss_get_state),
(Tree, tree_get_state),
]
for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(LossFunction, sgd_loss_get_instance),
(Tree, Tree_get_instance),
(Bunch, bunch_get_instance),
]
# Maps the loader name that the get_state functions of this module store
# under "__loader__" to the function that creates the instance from the state.
GET_INSTANCE_DISPATCH_MAPPING = {
    "sgd_loss_get_instance": sgd_loss_get_instance,
    "tree_get_instance": tree_get_instance,
}

# TODO: remove once support for sklearn<1.2 is dropped.
# Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no
Expand All @@ -178,6 +184,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src):
GET_STATE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state)
)
GET_INSTANCE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance)
)
GET_INSTANCE_DISPATCH_MAPPING[
"_DictWithDeprecatedKeys_get_instance"
] = _DictWithDeprecatedKeys_get_instance
Loading