From c5336ed7d30fb78eea08944a475503c4a991495e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 11 Jul 2024 15:37:58 +0200 Subject: [PATCH 1/6] ENH correctly restore default_factory of a defaultdict --- skops/io/_general.py | 45 +++++++++++++++++++++++++++++++++- skops/io/_trusted_types.py | 4 +++ skops/io/tests/test_persist.py | 18 +++++++++++--- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 46a3b1c8..105cf3f4 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -4,6 +4,7 @@ import json import operator import uuid +from collections import defaultdict from functools import partial from reprlib import Repr from types import FunctionType, MethodType @@ -14,6 +15,7 @@ from ._audit import Node, get_tree from ._protocol import PROTOCOL from ._trusted_types import ( + BUILTIN_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -80,6 +82,45 @@ def _construct(self): return content +def defaultdict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(type(obj)), + "__loader__": "DefaultDictNode", + } + content = {} + # explicitly pass a dict object instead of _DictWithDeprecatedKeys and + # later construct a _DictWithDeprecatedKeys object. + content["main"] = get_state(dict(obj), save_context) + content["default_factory"] = get_state(obj.default_factory, save_context) + res["content"] = content + return res + + +class DefaultDictNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: Optional[Sequence[str]] = None, + ) -> None: + super().__init__(state, load_context, trusted) + self.trusted = ["collections.defaultdict"] + self.children = { + "main": get_tree(state["content"]["main"], load_context, trusted=trusted), + "default_factory": get_tree( + state["content"]["default_factory"], + load_context, + trusted=trusted, + ), + } + + def _construct(self): + instance = defaultdict(**self.children["main"].construct()) + instance.default_factory = self.children["default_factory"].construct() + return instance + + def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { "__class__": obj.__class__.__name__, @@ -298,7 +339,7 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted( - trusted, PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + trusted, PRIMITIVE_TYPE_NAMES + BUILTIN_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES ) # We use a bare Node type here since a Node only checks the type in the # dict using __class__ and __module__ keys. @@ -597,6 +638,7 @@ def _construct(self): # tuples of type and function that gets the state of that type GET_STATE_DISPATCH_FUNCTIONS = [ (dict, dict_get_state), + (defaultdict, defaultdict_get_state), (list, list_get_state), (set, set_get_state), (tuple, tuple_get_state), @@ -616,6 +658,7 @@ def _construct(self): NODE_TYPE_MAPPING = { ("DictNode", PROTOCOL): DictNode, + ("DefaultDictNode", PROTOCOL): DefaultDictNode, ("ListNode", PROTOCOL): ListNode, ("SetNode", PROTOCOL): SetNode, ("TupleNode", PROTOCOL): TupleNode, diff --git a/skops/io/_trusted_types.py b/skops/io/_trusted_types.py index eebd7d37..6d35eeb3 100644 --- a/skops/io/_trusted_types.py +++ b/skops/io/_trusted_types.py @@ -10,6 +10,10 @@ PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES] +BUILTIN_TYPES = [list, set, map, tuple] + +BUILTIN_TYPE_NAMES = ["builtins." + t.__name__ for t in BUILTIN_TYPES] + SKLEARN_ESTIMATOR_TYPE_NAMES = [ get_type_name(estimator_class) for _, estimator_class in all_estimators() diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 9c0b7a57..c57b00dd 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -5,7 +5,7 @@ import operator import sys import warnings -from collections import Counter +from collections import Counter, defaultdict from functools import partial, wraps from pathlib import Path from zipfile import ZIP_DEFLATED, ZipFile @@ -56,6 +56,7 @@ from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES from skops.io._trusted_types import ( + BUILTIN_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -247,7 +248,7 @@ def _tested_ufuncs(): def _tested_types(): - for full_name in PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES: + for full_name in PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + BUILTIN_TYPE_NAMES: module_name, _, type_name = full_name.rpartition(".") yield gettype(module_name=module_name, cls_or_func=type_name) @@ -396,7 +397,9 @@ def test_can_trust_ufuncs(ufunc): @pytest.mark.parametrize( - "type_", _tested_types(), ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + "type_", + _tested_types(), + ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + BUILTIN_TYPE_NAMES, ) def test_can_trust_types(type_): dumped = dumps(type_) @@ -1078,3 +1081,12 @@ def test_trusted_bool_raises(tmp_path): with pytest.raises(TypeError, match="trusted must be a list of strings"): loads(dumps(10), trusted=True) # type: ignore + + +def test_defaultdict(): + """Test that we correctly restore a defaultdict.""" + obj = defaultdict(set) + obj["foo"] = "bar" + obj_loaded = loads(dumps(obj)) + assert obj_loaded == obj + assert obj_loaded.default_factory == obj.default_factory From ca649aeca062a7da517d86e2c9a270ce64e4109a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 11 Jul 2024 15:40:56 +0200 Subject: [PATCH 2/6] MNT add changelog for new version and bump version --- docs/changes.rst | 7 ++++++- skops/__init__.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index 7addfd3b..f5e13aeb 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -9,8 +9,13 @@ skops Changelog :depth: 1 :local: +v0.11 +----- +- Correctly restore ``default_factory`` when saving and loading a ``defaultdict``. + :pr:`433` by `Adrin Jalali`_. + v0.10 ----- +----- - Removes Pythn 3.8 support and adds Python 3.12 Support :pr:`418` by :user:`Thomas Lazarus `. - Removes a shortcut to add `sklearn-intelex` as a not dependency. :pr:`420` by :user:`Thomas Lazarus < lazarust > `. diff --git a/skops/__init__.py b/skops/__init__.py index 9d0ad282..ebd63a0e 100644 --- a/skops/__init__.py +++ b/skops/__init__.py @@ -16,7 +16,7 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = "0.10.dev0" +__version__ = "0.11.dev0" try: # This variable is injected in the __builtins__ by the build From 1bb275c59c0b77de082469325fd29fe267b38f13 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 12 Jul 2024 13:11:18 +0200 Subject: [PATCH 3/6] add OrderedDict trusted --- skops/io/_general.py | 2 +- skops/io/tests/test_external.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 105cf3f4..d426a36f 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -65,7 +65,7 @@ def __init__( trusted: Optional[Sequence[str]] = None, ) -> None: super().__init__(state, load_context, trusted) - self.trusted = self._get_trusted(trusted, [dict]) + self.trusted = self._get_trusted(trusted, [dict, "collections.OrderedDict"]) self.children = { "key_types": get_tree(state["key_types"], load_context, trusted=trusted), "content": { diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index a5d6c5ae..d9fa7916 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -76,7 +76,7 @@ def lgbm(self): def trusted(self): # TODO: adjust once more types are trusted by default return [ - "collections.defaultdict", + "collections.OrderedDict", "lightgbm.basic.Booster", "lightgbm.sklearn.LGBMClassifier", "lightgbm.sklearn.LGBMRegressor", From cfab7779769a78e0ad1a398cea0187f77275b370 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 22 Jul 2024 16:18:35 +0200 Subject: [PATCH 4/6] rename builtin type names to container type names --- skops/io/_general.py | 5 +++-- skops/io/_trusted_types.py | 4 ++-- skops/io/tests/test_persist.py | 8 +++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index d426a36f..bc25e7c2 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -15,7 +15,7 @@ from ._audit import Node, get_tree from ._protocol import PROTOCOL from ._trusted_types import ( - BUILTIN_TYPE_NAMES, + CONTAINER_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -339,7 +339,8 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted( - trusted, PRIMITIVE_TYPE_NAMES + BUILTIN_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + trusted, + PRIMITIVE_TYPE_NAMES + CONTAINER_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES, ) # We use a bare Node type here since a Node only checks the type in the # dict using __class__ and __module__ keys. diff --git a/skops/io/_trusted_types.py b/skops/io/_trusted_types.py index 6d35eeb3..e5bee6b8 100644 --- a/skops/io/_trusted_types.py +++ b/skops/io/_trusted_types.py @@ -10,9 +10,9 @@ PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES] -BUILTIN_TYPES = [list, set, map, tuple] +CONTAINER_TYPES = [list, set, map, tuple] -BUILTIN_TYPE_NAMES = ["builtins." + t.__name__ for t in BUILTIN_TYPES] +CONTAINER_TYPE_NAMES = ["builtins." + t.__name__ for t in CONTAINER_TYPES] SKLEARN_ESTIMATOR_TYPE_NAMES = [ get_type_name(estimator_class) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index c57b00dd..c6048bd4 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,7 +56,7 @@ from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES from skops.io._trusted_types import ( - BUILTIN_TYPE_NAMES, + CONTAINER_TYPE_NAMES, NUMPY_DTYPE_TYPE_NAMES, NUMPY_UFUNC_TYPE_NAMES, PRIMITIVE_TYPE_NAMES, @@ -248,7 +248,9 @@ def _tested_ufuncs(): def _tested_types(): - for full_name in PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + BUILTIN_TYPE_NAMES: + for full_name in ( + PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES + ): module_name, _, type_name = full_name.rpartition(".") yield gettype(module_name=module_name, cls_or_func=type_name) @@ -399,7 +401,7 @@ def test_can_trust_ufuncs(ufunc): @pytest.mark.parametrize( "type_", _tested_types(), - ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + BUILTIN_TYPE_NAMES, + ids=PRIMITIVE_TYPE_NAMES + NUMPY_DTYPE_TYPE_NAMES + CONTAINER_TYPE_NAMES, ) def test_can_trust_types(type_): dumped = dumps(type_) From f0a0e0588e3ed2b8c2ac02c23d328498f31293a7 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 22 Jul 2024 16:23:54 +0200 Subject: [PATCH 5/6] add test for OrderedDict --- skops/io/tests/test_persist.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index c6048bd4..3bb10b51 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -5,7 +5,7 @@ import operator import sys import warnings -from collections import Counter, defaultdict +from collections import Counter, OrderedDict, defaultdict from functools import partial, wraps from pathlib import Path from zipfile import ZIP_DEFLATED, ZipFile @@ -1092,3 +1092,10 @@ def test_defaultdict(): obj_loaded = loads(dumps(obj)) assert obj_loaded == obj assert obj_loaded.default_factory == obj.default_factory + + +@pytest.mark.parametrize("cls", [dict, OrderedDict]) +def test_dictionary(cls): + obj = cls({1: 5, 6: 3, 2: 4}) + loaded_obj = loads(dumps(obj)) + assert obj == loaded_obj From d1185207bf6798c481617e78dd59b0ecd3735145 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 8 Aug 2024 14:42:21 +0200 Subject: [PATCH 6/6] test type as well --- skops/io/tests/test_persist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 3bb10b51..34b4dae6 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -1099,3 +1099,4 @@ def test_dictionary(cls): obj = cls({1: 5, 6: 3, 2: 4}) loaded_obj = loads(dumps(obj)) assert obj == loaded_obj + assert type(obj) is cls