Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ skops Changelog

v0.6
----
- Added tabular regression example. :pr:`254` by :user:`Thomas Lazarus <lazarust>`.
- Add tabular regression example. :pr:`254` by :user:`Thomas Lazarus <lazarust>`.
- All public ``scipy.special`` ufuncs (Universal Functions) are trusted by default
by :func:`.io.load`. :pr:`295` by :user:`Omar Arab Oghli <omar-araboghli>`.
- Add a new function :func:`skops.card.Card.add_metric_frame` to help users
Expand All @@ -28,10 +28,13 @@ v0.6
- Add possibility to visualize a skops object and show untrusted types by using
:func:`skops.io.visualize`. For colored output, install `rich`: `pip install
rich`. :pr:`317` by `Benjamin Bossan`_.
- Fix issue with persisting :class:`numpy.random.Generator` using the skops
format (the object could be loaded correctly but security could not be
checked). :pr:`331` by `Benjamin Bossan`_.

v0.5
----
- Added CLI entrypoint support (:func:`.cli.entrypoint.main_cli`)
- Add CLI entrypoint support (:func:`.cli.entrypoint.main_cli`)
and a command line function to convert Pickle files
to Skops files (:func:`.cli._convert.main`). :pr:`249` by `Erin Aho`_
- Support more array-like data types for tabular data and list-like data types
Expand Down
18 changes: 12 additions & 6 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _construct(self):


def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
bit_generator_state = obj.bit_generator.state
bit_generator_state = get_state(obj.bit_generator.state, save_context)
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -183,15 +183,21 @@ def __init__(
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
self.children = {"bit_generator_state": state["content"]["bit_generator"]}
self.children = {
"bit_generator_state": get_tree(
state["content"]["bit_generator"], load_context
)
}
self.trusted = self._get_trusted(trusted, [np.random.Generator])

def _construct(self):
# first restore the state of the bit generator
bit_generator = gettype(
"numpy.random", self.children["bit_generator_state"]["bit_generator"]
)()
bit_generator.state = self.children["bit_generator_state"]
bit_generator_state = self.children["bit_generator_state"].construct()
bit_generator_cls = gettype(
"numpy.random", bit_generator_state["bit_generator"]
)
bit_generator = bit_generator_cls()
bit_generator.state = bit_generator_state

# next create the generator instance
return gettype(self.module_name, self.class_name)(bit_generator=bit_generator)
Expand Down
2 changes: 1 addition & 1 deletion skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# them. Old protocols are found in the 'old/' directory, with the protocol
# version appended to the corresponding module name.
modules = ["._general", "._numpy", "._scipy", "._sklearn"]
modules.extend([".old._general_v0"])
modules.extend([".old._general_v0", ".old._numpy_v0"])
for module_name in modules:
# register exposed functions for get_state and get_tree
module = importlib.import_module(module_name, package="skops.io")
Expand Down
37 changes: 37 additions & 0 deletions skops/io/old/_numpy_v0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import Any, Sequence

import numpy as np

from skops.io._audit import Node
from skops.io._utils import LoadContext, gettype

# Protocol version of the saved state handled by the loader node in this
# module; it is used as part of the key in NODE_TYPE_MAPPING below.
PROTOCOL = 0


class RandomGeneratorNode(Node):
    """Loader node for ``numpy.random.Generator`` states saved with protocol 0.

    The bit generator state found under ``state["content"]["bit_generator"]``
    is stored as-is in ``self.children`` (not wrapped by another node) and is
    assigned directly to a freshly created bit generator on construction.
    """

    def __init__(
        self,
        state: dict[str, Any],
        load_context: LoadContext,
        trusted: bool | Sequence[str] = False,
    ) -> None:
        """Store the raw bit generator state and resolve the trusted setting.

        Parameters
        ----------
        state : dict
            Saved state; ``state["content"]["bit_generator"]`` is read.

        load_context : LoadContext
            Passed through to the parent ``Node`` initializer.

        trusted : bool or sequence of str, default=False
            Passed to ``self._get_trusted`` together with
            ``[np.random.Generator]`` as the default.
        """
        super().__init__(state, load_context, trusted)
        raw_bit_generator_state = state["content"]["bit_generator"]
        self.children = {"bit_generator_state": raw_bit_generator_state}
        self.trusted = self._get_trusted(trusted, [np.random.Generator])

    def _construct(self):
        """Rebuild the generator from the stored bit generator state."""
        saved_state = self.children["bit_generator_state"]

        # first restore the state of the bit generator; its class is resolved
        # from "numpy.random" using the name stored in the state itself
        bit_generator_cls = gettype("numpy.random", saved_state["bit_generator"])
        bit_generator = bit_generator_cls()
        bit_generator.state = saved_state

        # next create the generator instance
        generator_cls = gettype(self.module_name, self.class_name)
        return generator_cls(bit_generator=bit_generator)


# Maps a (loader name, protocol version) pair to the node class in this module
# that loads states of that protocol.
NODE_TYPE_MAPPING = {
("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode,
}
29 changes: 18 additions & 11 deletions skops/io/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,11 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol:
data : bytes
The old state, as generated by ``skops.io.dumps``.

keys : list of str
keys : list of str, or None
The keys that lead to the old state. E.g. if we want to replace
``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``.
``state["foo"]["bar"]``, then keys should be ``["foo", "bar"]``. If
``keys=None``, the whole schema is instead replaced by the old state
being passed.

old_state : dict
The old state, as would be produced by the old ``get_state`` function.
Expand All @@ -239,17 +241,22 @@ def downgrade_state(*, data: bytes, keys: list[str], old_state: dict, protocol:
with ZipFile(io.BytesIO(data), "r") as zip_file:
schema = json.loads(zip_file.read("schema.json"))

# replace schema
schema["protocol"] = protocol
# replace state in schema using old state
if keys is None:
# replace all fields
schema = old_state
schema["__id__"] = id(schema)
else:
# replace specific field
state = schema
for key in keys[:-1]:
state = state[key]
state[keys[-1]] = old_state

# replace state using old state
state = schema
for key in keys[:-1]:
state = state[key]
state[keys[-1]] = old_state
# there has to be an __id__ field for memoization
state[keys[-1]]["__id__"] = id(schema)

# there has to be an __id__ field for memoization
state[keys[-1]]["__id__"] = id(schema)
schema["protocol"] = protocol

# dump into bytes
buffer = io.BytesIO()
Expand Down
9 changes: 8 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,14 @@ def test_random_state(random_state):
est = RandomStateEstimator(random_state=random_state).fit(None, None)
est.random_state_.random(123) # move RNG forwards

loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)

if hasattr(est, "__dict__"):
# what to do if object has no __dict__, like Generator?
assert_params_equal(est.__dict__, loaded.__dict__)

rand_floats_expected = est.random_state_.random(100)
rand_floats_loaded = loaded.random_state_.random(100)
np.testing.assert_equal(rand_floats_loaded, rand_floats_expected)
Expand Down
48 changes: 48 additions & 0 deletions skops/io/tests/test_persist_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,51 @@ def old_function_get_state(obj, save_context):
# check that loaded estimator is identical
assert_params_equal(estimator.__dict__, loaded.__dict__)
assert_method_outputs_equal(estimator, loaded, X)


@pytest.mark.parametrize(
    "rng",
    [
        np.random.default_rng(),
        np.random.Generator(np.random.PCG64DXSM(seed=123)),
    ],
    ids=["default_rng", "Generator"],
)
def test_random_generator_v0(rng):
    """Check that a Generator stored in the protocol 0 layout still loads.

    The state produced by the current ``dumps`` is replaced wholesale by the
    state the protocol 0 ``get_state`` function produced, then loaded again;
    the loaded generator must yield the same numbers as the original.
    """
    n_calls = 0

    def old_random_generator_get_state(obj, save_context):
        # random_generator_get_state as it was for protocol 0; the counter is
        # added for testing only, so we can verify this function was invoked
        nonlocal n_calls
        n_calls += 1

        raw_bit_generator_state = obj.bit_generator.state
        return {
            "__class__": obj.__class__.__name__,
            "__module__": get_module(type(obj)),
            "__loader__": "RandomGeneratorNode",
            "content": {"bit_generator": raw_bit_generator_state},
        }

    rng.random(123)  # move RNG forwards
    dumped = dumps(rng)

    # important: downgrade the whole state to mimic older version
    protocol_0_state = old_random_generator_get_state(rng, None)
    downgraded = downgrade_state(
        data=dumped,
        keys=None,
        old_state=protocol_0_state,
        protocol=0,
    )

    # old loader only worked with trusted=True, see #329
    loaded = loads(downgraded, trusted=True)

    # sanity check: ensure that the old get_state function was really called
    assert n_calls == 1

    expected_floats = rng.random(100)
    loaded_floats = loaded.random(100)
    np.testing.assert_equal(loaded_floats, expected_floats)