diff --git a/docs/changes.rst b/docs/changes.rst index 7addfd3b..f5e13aeb 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -9,8 +9,13 @@ skops Changelog :depth: 1 :local: +v0.11 +----- +- Correctly restore ``default_factory`` when saving and loading a ``defaultdict``. + :pr:`433` by `Adrin Jalali`_. + v0.10 ----- +----- - Removes Pythn 3.8 support and adds Python 3.12 Support :pr:`418` by :user:`Thomas Lazarus `. - Removes a shortcut to add `sklearn-intelex` as a not dependency. :pr:`420` by :user:`Thomas Lazarus < lazarust > `. diff --git a/skops/io/_general.py b/skops/io/_general.py index 46a3b1c8..bc25e7c2 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -4,6 +4,7 @@ import json import operator import uuid +from collections import defaultdict from functools import partial from reprlib import Repr from types import FunctionType, MethodType @@ -14,6 +15,7 @@ from ._audit import Node, get_tree from ._protocol import PROTOCOL from ._trusted_types import ( + CONTAINER_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -63,7 +65,7 @@ def __init__( trusted: Optional[Sequence[str]] = None, ) -> None: super().__init__(state, load_context, trusted) - self.trusted = self._get_trusted(trusted, [dict]) + self.trusted = self._get_trusted(trusted, [dict, "collections.OrderedDict"]) self.children = { "key_types": get_tree(state["key_types"], load_context, trusted=trusted), "content": { @@ -80,6 +82,45 @@ def _construct(self): return content +def defaultdict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + """Return the state of a ``defaultdict``: its items and its ``default_factory``.""" + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(type(obj)), + "__loader__": "DefaultDictNode", + } + content = {} + # Store the items as a plain ``dict``; the factory is saved separately below. 
+ content["main"] = get_state(dict(obj), save_context) + content["default_factory"] = get_state(obj.default_factory, save_context) + res["content"] = content + return res + + +class DefaultDictNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: Optional[Sequence[str]] = None, + ) -> None: + super().__init__(state, load_context, trusted) + self.trusted = ["collections.defaultdict"] + self.children = { + "main": get_tree(state["content"]["main"], load_context, trusted=trusted), + "default_factory": get_tree( + state["content"]["default_factory"], + load_context, + trusted=trusted, + ), + } + + def _construct(self): + # Pass the items positionally: ``defaultdict(**items)`` only accepts str keys. + items = self.children["main"].construct() + return defaultdict(self.children["default_factory"].construct(), items) + + def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -298,7 +339,8 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted( - trusted, PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + trusted, + PRIMITIVE_TYPE_NAMES + CONTAINER_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES, ) # We use a bare Node type here since a Node only checks the type in the # dict using __class__ and __module__ keys. 
@@ -597,6 +639,7 @@ def _construct(self): # tuples of type and function that gets the state of that type GET_STATE_DISPATCH_FUNCTIONS = [ (dict, dict_get_state), + (defaultdict, defaultdict_get_state), (list, list_get_state), (set, set_get_state), (tuple, tuple_get_state), @@ -616,6 +659,7 @@ def _construct(self): NODE_TYPE_MAPPING = { ("DictNode", PROTOCOL): DictNode, + ("DefaultDictNode", PROTOCOL): DefaultDictNode, ("ListNode", PROTOCOL): ListNode, ("SetNode", PROTOCOL): SetNode, ("TupleNode", PROTOCOL): TupleNode, diff --git a/skops/io/_trusted_types.py b/skops/io/_trusted_types.py index eebd7d37..e5bee6b8 100644 --- a/skops/io/_trusted_types.py +++ b/skops/io/_trusted_types.py @@ -10,6 +10,10 @@ PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES] +CONTAINER_TYPES = [list, set, map, tuple] + +CONTAINER_TYPE_NAMES = ["builtins." + t.__name__ for t in CONTAINER_TYPES] + SKLEARN_ESTIMATOR_TYPE_NAMES = [ get_type_name(estimator_class) for _, estimator_class in all_estimators() diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index a5d6c5ae..d9fa7916 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -76,7 +76,7 @@ def lgbm(self): def trusted(self): # TODO: adjust once more types are trusted by default return [ - "collections.defaultdict", + "collections.OrderedDict", "lightgbm.basic.Booster", "lightgbm.sklearn.LGBMClassifier", "lightgbm.sklearn.LGBMRegressor", diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 9c0b7a57..34b4dae6 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -5,7 +5,7 @@ import operator import sys import warnings -from collections import Counter +from collections import Counter, OrderedDict, defaultdict from functools import partial, wraps from pathlib import Path from zipfile import ZIP_DEFLATED, ZipFile @@ -56,6 +56,7 @@ from skops.io._audit import NODE_TYPE_MAPPING, get_tree from 
skops.io._sklearn import UNSUPPORTED_TYPES from skops.io._trusted_types import ( + CONTAINER_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -247,7 +248,9 @@ def _tested_ufuncs(): def _tested_types(): - for full_name in PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES: + for full_name in ( + PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES + ): module_name, _, type_name = full_name.rpartition(".") yield gettype(module_name=module_name, cls_or_func=type_name) @@ -396,7 +399,9 @@ def test_can_trust_ufuncs(ufunc): @pytest.mark.parametrize( - "type_", _tested_types(), ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + "type_", + _tested_types(), + ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES, ) def test_can_trust_types(type_): dumped = dumps(type_) @@ -1078,3 +1083,21 @@ def test_trusted_bool_raises(tmp_path): with pytest.raises(TypeError, match="trusted must be a list of strings"): loads(dumps(10), trusted=True) # type: ignore + + +def test_defaultdict(): + """Test that we correctly restore a defaultdict.""" + obj = defaultdict(set) + obj["foo"] = "bar" + obj_loaded = loads(dumps(obj)) + assert obj_loaded == obj + assert obj_loaded.default_factory == obj.default_factory + + +@pytest.mark.parametrize("cls", [dict, OrderedDict]) +def test_dictionary(cls): + """Test that ``dict`` and ``OrderedDict`` round-trip with their type intact.""" + obj = cls({1: 5, 6: 3, 2: 4}) + loaded_obj = loads(dumps(obj)) + assert obj == loaded_obj + assert type(loaded_obj) is cls