diff --git a/docs/changes.rst b/docs/changes.rst index bfcadbc2..95c446f3 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -11,7 +11,7 @@ skops Changelog v0.6 ---- -- Added tabular regression example. :pr:`254` by :user:`Thomas Lazarus `. +- Add tabular regression example. :pr:`254` by :user:`Thomas Lazarus `. - All public ``scipy.special`` ufuncs (Universal Functions) are trusted by default by :func:`.io.load`. :pr:`295` by :user:`Omar Arab Oghli `. - Add a new function :func:`skops.card.Card.add_metric_frame` to help users @@ -28,10 +28,13 @@ v0.6 - Add possibility to visualize a skops object and show untrusted types by using :func:`skops.io.visualize`. For colored output, install `rich`: `pip install rich`. :pr:`317` by `Benjamin Bossan`_. +- Fix issue with persisting :class:`numpy.random.Generator` using the skops + format (the object could be loaded correctly but security could not be + checked). :pr:`331` by `Benjamin Bossan`_. v0.5 ---- -- Added CLI entrypoint support (:func:`.cli.entrypoint.main_cli`) +- Add CLI entrypoint support (:func:`.cli.entrypoint.main_cli`) and a command line function to convert Pickle files to Skops files (:func:`.cli._convert.main`). 
:pr:`249` by `Erin Aho`_ - Support more array-like data types for tabular data and list-like data types diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 1f2cef82..d5243f15 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -165,7 +165,7 @@ def _construct(self): def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - bit_generator_state = obj.bit_generator.state + bit_generator_state = get_state(obj.bit_generator.state, save_context) res = { "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), @@ -183,15 +183,21 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"bit_generator_state": state["content"]["bit_generator"]} + self.children = { + "bit_generator_state": get_tree( + state["content"]["bit_generator"], load_context + ) + } self.trusted = self._get_trusted(trusted, [np.random.Generator]) def _construct(self): # first restore the state of the bit generator - bit_generator = gettype( - "numpy.random", self.children["bit_generator_state"]["bit_generator"] - )() - bit_generator.state = self.children["bit_generator_state"] + bit_generator_state = self.children["bit_generator_state"].construct() + bit_generator_cls = gettype( + "numpy.random", bit_generator_state["bit_generator"] + ) + bit_generator = bit_generator_cls() + bit_generator.state = bit_generator_state # next create the generator instance return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 12796e05..a90cb1fc 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -16,7 +16,7 @@ # them. Old protocols are found in the 'old/' directory, with the protocol # version appended to the corresponding module name. 
modules = ["._general", "._numpy", "._scipy", "._sklearn"] -modules.extend([".old._general_v0"]) +modules.extend([".old._general_v0", ".old._numpy_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree module = importlib.import_module(module_name, package="skops.io") diff --git a/skops/io/old/_numpy_v0.py b/skops/io/old/_numpy_v0.py new file mode 100644 index 00000000..04e258ee --- /dev/null +++ b/skops/io/old/_numpy_v0.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, Sequence + +import numpy as np + +from skops.io._audit import Node +from skops.io._utils import LoadContext, gettype + +PROTOCOL = 0 + + +class RandomGeneratorNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__(state, load_context, trusted) + self.children = {"bit_generator_state": state["content"]["bit_generator"]} + self.trusted = self._get_trusted(trusted, [np.random.Generator]) + + def _construct(self): + # first restore the state of the bit generator + bit_generator = gettype( + "numpy.random", self.children["bit_generator_state"]["bit_generator"] + )() + bit_generator.state = self.children["bit_generator_state"] + + # next create the generator instance + return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) + + +NODE_TYPE_MAPPING = { + ("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode, +} diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index ce9c45d9..d355f959 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -219,9 +219,11 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: data : bytes The old state, as generated by ``skops.io.dumps``. - keys : list of str + keys : list of str, or None The keys that lead to the old state. E.g. if we want to replace - ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. 
+ ``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. If + ``keys=None``, the whole schema is instead replaced by the old state + being passed. old_state : dict The old state, as would be produced by the old ``get_state`` function. @@ -239,17 +241,22 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - # replace schema - schema["protocol"] = protocol + # replace state in schema using old state + if keys is None: + # replace all fields + schema = old_state + schema["__id__"] = id(schema) + else: + # replace specific field + state = schema + for key in keys[:-1]: + state = state[key] + state[keys[-1]] = old_state - # replace state using old state - state = schema - for key in keys[:-1]: - state = state[key] - state[keys[-1]] = old_state + # there has to be an __id__ field for memoization + state[keys[-1]]["__id__"] = id(schema) - # there has to be an __id__ field for memoization - state[keys[-1]]["__id__"] = id(schema) + schema["protocol"] = protocol # dump into bytes buffer = io.BytesIO() diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 19a6c7eb..1724a4c1 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -421,7 +421,14 @@ def test_random_state(random_state): est = RandomStateEstimator(random_state=random_state).fit(None, None) est.random_state_.random(123) # move RNG forwards - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) + + if hasattr(est, "__dict__"): + # what to do if object has no __dict__, like Generator? 
+ assert_params_equal(est.__dict__, loaded.__dict__) + rand_floats_expected = est.random_state_.random(100) rand_floats_loaded = loaded.random_state_.random(100) np.testing.assert_equal(rand_floats_loaded, rand_floats_expected) diff --git a/skops/io/tests/test_persist_old.py b/skops/io/tests/test_persist_old.py index 69f297fa..504b16c0 100644 --- a/skops/io/tests/test_persist_old.py +++ b/skops/io/tests/test_persist_old.py @@ -66,3 +66,51 @@ def old_function_get_state(obj, save_context): # check that loaded estimator is identical assert_params_equal(estimator.__dict__, loaded.__dict__) assert_method_outputs_equal(estimator, loaded, X) + + +@pytest.mark.parametrize( + "rng", + [ + np.random.default_rng(), + np.random.Generator(np.random.PCG64DXSM(seed=123)), + ], + ids=["default_rng", "Generator"], +) +def test_random_generator_v0(rng): + call_count = 0 + + # random_generator_get_state as it was for protocol 0 + def old_random_generator_get_state(obj, save_context): + # added for testing + nonlocal call_count + call_count += 1 + # end + + bit_generator_state = obj.bit_generator.state + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(type(obj)), + "__loader__": "RandomGeneratorNode", + "content": {"bit_generator": bit_generator_state}, + } + return res + + rng.random(123) # move RNG forwards + dumped = dumps(rng) + # important: downgrade the whole state to mimic older version + downgraded = downgrade_state( + data=dumped, + keys=None, + old_state=old_random_generator_get_state(rng, None), + protocol=0, + ) + + # old loader only worked with trusted=True, see #329 + loaded = loads(downgraded, trusted=True) + + # sanity check: ensure that the old get_state function was really called + assert call_count == 1 + + rand_floats_expected = rng.random(100) + rand_floats_loaded = loaded.random(100) + np.testing.assert_equal(rand_floats_loaded, rand_floats_expected)