diff --git a/skops/io/_general.py b/skops/io/_general.py index d32d63ae..c0594319 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -2,7 +2,7 @@ import json from functools import partial -from types import FunctionType +from types import FunctionType, MethodType from typing import Any import numpy as np @@ -242,6 +242,30 @@ def object_get_instance(state, src): return instance +def method_get_state(obj: Any, save_state: SaveState): + # This method is used to persist bound methods, which are + # dependent on a specific instance of an object. + # It stores the state of the object the method is bound to, + # and prepares both to be persisted. + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(obj), + "__loader__": "method_get_instance", + "content": { + "func": obj.__func__.__name__, + "obj": get_state(obj.__self__, save_state), + }, + } + + return res + + +def method_get_instance(state, src): + loaded_obj = object_get_instance(state["content"]["obj"], src) + method = getattr(loaded_obj, state["content"]["func"]) + return method + + def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: raise UnsupportedTypeException(obj) @@ -253,6 +277,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: (tuple, tuple_get_state), (slice, slice_get_state), (FunctionType, function_get_state), + (MethodType, method_get_state), (partial, partial_get_state), (type, type_get_state), (object, object_get_state), @@ -264,6 +289,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "tuple_get_instance": tuple_get_instance, "slice_get_instance": slice_get_instance, "function_get_instance": function_get_instance, + "method_get_instance": method_get_instance, "partial_get_instance": partial_get_instance, "type_get_instance": type_get_instance, "object_get_instance": object_get_instance, diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 80c06aa1..2b6f9749 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -5,7 +5,6 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from types import FunctionType from typing import Any from zipfile import ZipFile @@ -61,10 +60,6 @@ def _import_obj(module, cls_or_func, package=None): def gettype(state): if "__module__" in state and "__class__" in state: - if state["__class__"] == "function": - # This special case is due to how functions are serialized. We - # could try to change it. - return FunctionType return _import_obj(state["__module__"], state["__class__"]) return None diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 1045e048..b999487b 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -792,6 +792,106 @@ def test_get_instance_unknown_type_error_msg(): get_instance(state, None) +class _BoundMethodHolder: + """Used to test the ability to serialize and deserialize bound methods""" + + def __init__(self, object_state: str): + # Initialize with some state to make sure state is persisted + self.object_state = object_state + # bind some method to this object, could be any persistable function + self.chosen_function = np.log + + def bound_method(self, x): + return self.chosen_function(x) + + def other_bound_method(self, x): + # arbitrary other function, used for checking single instance loaded + return self.chosen_function(x) + + +class TestPersistingBoundMethods: + @staticmethod + def assert_transformer_persisted_correctly( + loaded_transformer: FunctionTransformer, + original_transformer: FunctionTransformer, + ): + """Checks that a persisted and original transformer are equivalent, including + the func passed to it + """ + assert loaded_transformer.func.__name__ == original_transformer.func.__name__ + + assert_params_equal( + loaded_transformer.func.__self__.__dict__, + original_transformer.func.__self__.__dict__, + ) + assert_params_equal(loaded_transformer.__dict__, original_transformer.__dict__) + + @staticmethod + def assert_bound_method_holder_persisted_correctly( + original_obj: _BoundMethodHolder, loaded_obj: _BoundMethodHolder + ): + """Checks that the persisted and original instances of _BoundMethodHolder are + equivalent + """ + assert original_obj.bound_method.__name__ == loaded_obj.bound_method.__name__ + assert original_obj.chosen_function == loaded_obj.chosen_function + + assert_params_equal(original_obj.__dict__, loaded_obj.__dict__) + + def test_for_base_case_returns_as_expected(self): + initial_state = "This is an arbitrary state" + obj = _BoundMethodHolder(object_state=initial_state) + bound_function = obj.bound_method + transformer = FunctionTransformer(func=bound_function) + + loaded_transformer = loads(dumps(transformer)) + loaded_obj = loaded_transformer.func.__self__ + + self.assert_transformer_persisted_correctly(loaded_transformer, transformer) + self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) + + def test_when_object_is_changed_after_init_works_as_expected(self): + # given change to object with bound method after initialisation, + # make sure still persists correctly + + initial_state = "This is an arbitrary state" + obj = _BoundMethodHolder(object_state=initial_state) + obj.chosen_function = np.sqrt + bound_function = obj.bound_method + + transformer = FunctionTransformer(func=bound_function) + + loaded_transformer = loads(dumps(transformer)) + loaded_obj = loaded_transformer.func.__self__ + + self.assert_transformer_persisted_correctly(loaded_transformer, transformer) + self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) + + @pytest.mark.xfail( + reason="Can't load an object as a single instance if referenced multiple times" + ) + def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): + obj = _BoundMethodHolder(object_state="") + + transformer = FunctionTransformer( + func=obj.bound_method, inverse_func=obj.other_bound_method + ) + + loaded_transformer = loads(dumps(transformer)) + + # check that both func and inverse_func are from the same object instance + loaded_0 = loaded_transformer.func.__self__ + loaded_1 = loaded_transformer.inverse_func.__self__ + assert loaded_0 is loaded_1 + + @pytest.mark.xfail(reason="Failing due to circular self reference") + def test_scipy_stats(self, tmp_path): + from scipy import stats + + estimator = FunctionTransformer(func=stats.zipf) + loads(dumps(estimator)) + + class CustomEstimator(BaseEstimator): """Estimator with np array, np scalar, and sparse matrix attribute"""