From 892b3fdcc083689d720abae04266ffaa2c7004f1 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 28 Mar 2023 11:23:24 +0200 Subject: [PATCH 1/3] FIX: Bug in random generator persistence One of its children was not persisted correctly (pure dict instead of DictNode), which made it impossible to check the untrusted types on it. This fix correctly stores the state as a node. Tests were extended to uncover this bug. The old persistence scheme was preserved for protocol version 0 -- even though it is buggy, it can work under some circumstances. --- skops/io/_numpy.py | 18 +++++++---- skops/io/_persist.py | 2 +- skops/io/old/_numpy_v0.py | 37 +++++++++++++++++++++++ skops/io/tests/_utils.py | 29 +++++++++++------- skops/io/tests/test_persist.py | 9 +++++- skops/io/tests/test_persist_old.py | 48 ++++++++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 19 deletions(-) create mode 100644 skops/io/old/_numpy_v0.py diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 1f2cef82..d5243f15 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -165,7 +165,7 @@ def _construct(self): def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - bit_generator_state = obj.bit_generator.state + bit_generator_state = get_state(obj.bit_generator.state, save_context) res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -183,15 +183,21 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"bit_generator_state": state["content"]["bit_generator"]} + self.children = { + "bit_generator_state": get_tree( + state["content"]["bit_generator"], load_context + ) + } self.trusted = self._get_trusted(trusted, [np.random.Generator]) def _construct(self): # first restore the state of the bit generator - bit_generator = gettype( - "numpy.random", self.children["bit_generator_state"]["bit_generator"] - )() - bit_generator.state = 
self.children["bit_generator_state"] + bit_generator_state = self.children["bit_generator_state"].construct() + bit_generator_cls = gettype( + "numpy.random", bit_generator_state["bit_generator"] + ) + bit_generator = bit_generator_cls() + bit_generator.state = bit_generator_state # next create the generator instance return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 12796e05..a90cb1fc 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -16,7 +16,7 @@ # them. Old protocols are found in the 'old/' directory, with the protocol # version appended to the corresponding module name. modules = ["._general", "._numpy", "._scipy", "._sklearn"] -modules.extend([".old._general_v0"]) +modules.extend([".old._general_v0", ".old._numpy_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree module = importlib.import_module(module_name, package="skops.io") diff --git a/skops/io/old/_numpy_v0.py b/skops/io/old/_numpy_v0.py new file mode 100644 index 00000000..04e258ee --- /dev/null +++ b/skops/io/old/_numpy_v0.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, Sequence + +import numpy as np + +from skops.io._audit import Node +from skops.io._utils import LoadContext, gettype + +PROTOCOL = 0 + + +class RandomGeneratorNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__(state, load_context, trusted) + self.children = {"bit_generator_state": state["content"]["bit_generator"]} + self.trusted = self._get_trusted(trusted, [np.random.Generator]) + + def _construct(self): + # first restore the state of the bit generator + bit_generator = gettype( + "numpy.random", self.children["bit_generator_state"]["bit_generator"] + )() + bit_generator.state = self.children["bit_generator_state"] + + # next create the generator 
instance + return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) + + +NODE_TYPE_MAPPING = { + ("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode, +} diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index ce9c45d9..d355f959 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -219,9 +219,11 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: data : bytes The old state, as generated by ``skops.io.dumps``. - keys : list of str + keys : list of str, or None The keys that lead to the old state. E.g. if we want to replace - ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. + ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. If + ``keys=None``, the whole schema is instead replaced by the old state + being passed. old_state : dict The old state, as would be produced by the old ``get_state`` function. @@ -239,17 +241,22 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - # replace schema - schema["protocol"] = protocol + # replace state in schema using old state + if keys is None: + # replace all fields + schema = old_state + schema["__id__"] = id(schema) + else: + # replace specific field + state = schema + for key in keys[:-1]: + state = state[key] + state[keys[-1]] = old_state - # replace state using old state - state = schema - for key in keys[:-1]: - state = state[key] - state[keys[-1]] = old_state + # there has to be an __id__ field for memoization + state[keys[-1]]["__id__"] = id(schema) - # there has to be an __id__ field for memoization - state[keys[-1]]["__id__"] = id(schema) + schema["protocol"] = protocol # dump into bytes buffer = io.BytesIO() diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 19a6c7eb..1724a4c1 100644 --- a/skops/io/tests/test_persist.py +++ 
b/skops/io/tests/test_persist.py @@ -421,7 +421,14 @@ def test_random_state(random_state): est = RandomStateEstimator(random_state=random_state).fit(None, None) est.random_state_.random(123) # move RNG forwards - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) + + if hasattr(est, "__dict__"): + # what to do if object has no __dict__, like Generator? + assert_params_equal(est.__dict__, loaded.__dict__) + rand_floats_expected = est.random_state_.random(100) rand_floats_loaded = loaded.random_state_.random(100) np.testing.assert_equal(rand_floats_loaded, rand_floats_expected) diff --git a/skops/io/tests/test_persist_old.py b/skops/io/tests/test_persist_old.py index 69f297fa..504b16c0 100644 --- a/skops/io/tests/test_persist_old.py +++ b/skops/io/tests/test_persist_old.py @@ -66,3 +66,51 @@ def old_function_get_state(obj, save_context): # check that loaded estimator is identical assert_params_equal(estimator.__dict__, loaded.__dict__) assert_method_outputs_equal(estimator, loaded, X) + + +@pytest.mark.parametrize( + "rng", + [ + np.random.default_rng(), + np.random.Generator(np.random.PCG64DXSM(seed=123)), + ], + ids=["default_rng", "Generator"], +) +def test_random_generator_v0(rng): + call_count = 0 + + # random_generator_get_state as it was for protocol 0 + def old_random_generator_get_state(obj, save_context): + # added for testing + nonlocal call_count + call_count += 1 + # end + + bit_generator_state = obj.bit_generator.state + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(type(obj)), + "__loader__": "RandomGeneratorNode", + "content": {"bit_generator": bit_generator_state}, + } + return res + + rng.random(123) # move RNG forwards + dumped = dumps(rng) + # important: downgrade the whole state to mimic older version + downgraded = downgrade_state( + data=dumped, + keys=None, + 
old_state=old_random_generator_get_state(rng, None), + protocol=0, + ) + + # old loader only worked with trusted=True, see #329 + loaded = loads(downgraded, trusted=True) + + # sanity check: ensure that the old get_state function was really called + assert call_count == 1 + + rand_floats_expected = rng.random(100) + rand_floats_loaded = loaded.random(100) + np.testing.assert_equal(rand_floats_loaded, rand_floats_expected) From 79ed16075f5da159b32ce89001a046d48d4e7779 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 28 Mar 2023 15:31:53 +0200 Subject: [PATCH 2/3] Add changes.rst entry for this fix --- docs/changes.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/changes.rst b/docs/changes.rst index bfcadbc2..6de65f26 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -28,6 +28,9 @@ v0.6 - Add possibility to visualize a skops object and show untrusted types by using :func:`skops.io.visualize`. For colored output, install `rich`: `pip install rich`. :pr:`317` by `Benjamin Bossan`_. +- Fix issue with persisting :class:`numpy.random.Generator` using the skops + format (the object could be loaded correctly but security could not be + checked). :pr:`331` by `Benjamin Bossan`_. v0.5 ---- From 429a5a73630c07e64b42b1e1d1e6cf63b62bf9de Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 28 Mar 2023 15:32:12 +0200 Subject: [PATCH 3/3] Normalize tense of changes.rst entries --- docs/changes.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index 6de65f26..95c446f3 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -11,7 +11,7 @@ skops Changelog v0.6 ---- -- Added tabular regression example. :pr:`254` by :user:`Thomas Lazarus `. +- Add tabular regression example. :pr:`254` by :user:`Thomas Lazarus `. - All public ``scipy.special`` ufuncs (Universal Functions) are trusted by default by :func:`.io.load`. :pr:`295` by :user:`Omar Arab Oghli `. 
- Add a new function :func:`skops.card.Card.add_metric_frame` to help users @@ -34,7 +34,7 @@ v0.6 v0.5 ---- -- Added CLI entrypoint support (:func:`.cli.entrypoint.main_cli`) +- Add CLI entrypoint support (:func:`.cli.entrypoint.main_cli`) and a command line function to convert Pickle files to Skops files (:func:`.cli._convert.main`). :pr:`249` by `Erin Aho`_ - Support more array-like data types for tabular data and list-like data types