Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ skops Changelog
v0.6
----
- Added tabular regression example. :pr: `254` by `Thomas Lazarus`
- All public ``scipy.special`` ufuncs (Universal Functions) are trusted by default
by :func:`.io.load`. :pr:`295` by :user:`Omar Arab Oghli <omar-araboghli>`.

v0.5
----
Expand Down
10 changes: 7 additions & 3 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import numpy as np

from ._audit import Node, get_tree
from ._trusted_types import PRIMITIVE_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES
from ._trusted_types import (
PRIMITIVE_TYPE_NAMES,
SCIPY_UFUNC_TYPE_NAMES,
SKLEARN_ESTIMATOR_TYPE_NAMES,
)
from ._utils import (
LoadContext,
SaveContext,
Expand Down Expand Up @@ -195,7 +199,7 @@ def __init__(
) -> None:
super().__init__(state, load_context, trusted)
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, [])
self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES)
self.children = {"content": state["content"]}

def _construct(self):
Expand All @@ -212,7 +216,7 @@ def _get_function_name(self) -> str:
)

def get_unsafe_set(self) -> set[str]:
if self.trusted is True:
if (self.trusted is True) or (self._get_function_name() in self.trusted):
return set()

return {self._get_function_name()}
Expand Down
13 changes: 13 additions & 0 deletions skops/io/_trusted_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np
import scipy
from sklearn.utils import all_estimators

from ._utils import get_type_name
Expand All @@ -11,3 +13,14 @@
for _, estimator_class in all_estimators()
if get_type_name(estimator_class).startswith("sklearn.")
]

SCIPY_UFUNC_TYPE_NAMES = sorted(
set(
[
get_type_name(getattr(scipy.special, attr))
for attr in dir(scipy.special)
if isinstance(getattr(scipy.special, attr), np.ufunc)
and get_type_name(getattr(scipy.special, attr)).startswith("scipy")
]
)
)
19 changes: 17 additions & 2 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
from skops.io import dump, dumps, get_untrusted_types, load, loads
from skops.io._audit import NODE_TYPE_MAPPING, get_tree
from skops.io._sklearn import UNSUPPORTED_TYPES
from skops.io._trusted_types import SKLEARN_ESTIMATOR_TYPE_NAMES
from skops.io._utils import LoadContext, SaveContext, _get_state, get_state
from skops.io._trusted_types import SCIPY_UFUNC_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES
from skops.io._utils import LoadContext, SaveContext, _get_state, get_state, gettype
from skops.io.exceptions import UnsupportedTypeException
from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal

Expand Down Expand Up @@ -221,6 +221,12 @@ def _tested_estimators(type_filter=None):
)


def _tested_ufuncs():
for full_name in SCIPY_UFUNC_TYPE_NAMES:
module_name, _, ufunc_name = full_name.rpartition(".")
yield gettype(module_name=module_name, cls_or_func=ufunc_name)


def _unsupported_estimators(type_filter=None):
for name, Estimator in all_estimators(type_filter=type_filter):
if Estimator not in UNSUPPORTED_TYPES:
Expand Down Expand Up @@ -345,9 +351,18 @@ def test_can_persist_fitted(estimator):
assert_params_equal(estimator.__dict__, loaded.__dict__)

assert not any(type_ in SKLEARN_ESTIMATOR_TYPE_NAMES for type_ in untrusted_types)
assert not any(type_ in SCIPY_UFUNC_TYPE_NAMES for type_ in untrusted_types)
assert_method_outputs_equal(estimator, loaded, X)


@pytest.mark.parametrize("ufunc", _tested_ufuncs(), ids=SCIPY_UFUNC_TYPE_NAMES)
def test_can_trust_ufuncs(ufunc):
dumped = dumps(ufunc)
untrusted_types = get_untrusted_types(data=dumped)
assert len(untrusted_types) == 0
# TODO: extend with numpy ufuncs


@pytest.mark.parametrize(
"estimator", _unsupported_estimators(), ids=_get_check_estimator_ids
)
Expand Down