diff --git a/docs/changes.rst b/docs/changes.rst index b2bdfe68..28fa9975 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -12,6 +12,8 @@ skops Changelog v0.6 ---- - Added tabular regression example. :pr: `254` by `Thomas Lazarus` +- All public ``scipy.special`` ufuncs (Universal Functions) are trusted by default + by :func:`.io.load`. :pr:`295` by :user:`Omar Arab Oghli <omar-araboghli>`. v0.5 ---- diff --git a/skops/io/_general.py b/skops/io/_general.py index 9bc2254a..51a184f8 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -10,7 +10,11 @@ import numpy as np from ._audit import Node, get_tree -from ._trusted_types import PRIMITIVE_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES +from ._trusted_types import ( + PRIMITIVE_TYPE_NAMES, + SCIPY_UFUNC_TYPE_NAMES, + SKLEARN_ESTIMATOR_TYPE_NAMES, +) from ._utils import ( LoadContext, SaveContext, @@ -195,7 +199,7 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) # TODO: what do we trust? - self.trusted = self._get_trusted(trusted, []) + self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) self.children = {"content": state["content"]} def _construct(self): @@ -212,7 +216,7 @@ def _get_function_name(self) -> str: ) def get_unsafe_set(self) -> set[str]: - if self.trusted is True: + if (self.trusted is True) or (self._get_function_name() in self.trusted): return set() return {self._get_function_name()} diff --git a/skops/io/_trusted_types.py b/skops/io/_trusted_types.py index 1ef1b826..39e55573 100644 --- a/skops/io/_trusted_types.py +++ b/skops/io/_trusted_types.py @@ -1,3 +1,5 @@ +import numpy as np +import scipy from sklearn.utils import all_estimators from ._utils import get_type_name @@ -11,3 +13,14 @@ for _, estimator_class in all_estimators() if get_type_name(estimator_class).startswith("sklearn.") ] + +SCIPY_UFUNC_TYPE_NAMES = sorted( + set( + [ + get_type_name(getattr(scipy.special, attr)) + for attr in dir(scipy.special) + if isinstance(getattr(scipy.special, attr), 
np.ufunc) + and get_type_name(getattr(scipy.special, attr)).startswith("scipy") + ] + ) +) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index bd9c1e20..c12aa8d7 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -51,8 +51,8 @@ from skops.io import dump, dumps, get_untrusted_types, load, loads from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES -from skops.io._trusted_types import SKLEARN_ESTIMATOR_TYPE_NAMES -from skops.io._utils import LoadContext, SaveContext, _get_state, get_state +from skops.io._trusted_types import SCIPY_UFUNC_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES +from skops.io._utils import LoadContext, SaveContext, _get_state, get_state, gettype from skops.io.exceptions import UnsupportedTypeException from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal @@ -221,6 +221,12 @@ def _tested_estimators(type_filter=None): ) +def _tested_ufuncs(): + for full_name in SCIPY_UFUNC_TYPE_NAMES: + module_name, _, ufunc_name = full_name.rpartition(".") + yield gettype(module_name=module_name, cls_or_func=ufunc_name) + + def _unsupported_estimators(type_filter=None): for name, Estimator in all_estimators(type_filter=type_filter): if Estimator not in UNSUPPORTED_TYPES: @@ -345,9 +351,18 @@ def test_can_persist_fitted(estimator): assert_params_equal(estimator.__dict__, loaded.__dict__) assert not any(type_ in SKLEARN_ESTIMATOR_TYPE_NAMES for type_ in untrusted_types) + assert not any(type_ in SCIPY_UFUNC_TYPE_NAMES for type_ in untrusted_types) assert_method_outputs_equal(estimator, loaded, X) +@pytest.mark.parametrize("ufunc", _tested_ufuncs(), ids=SCIPY_UFUNC_TYPE_NAMES) +def test_can_trust_ufuncs(ufunc): + dumped = dumps(ufunc) + untrusted_types = get_untrusted_types(data=dumped) + assert len(untrusted_types) == 0 + # TODO: extend with numpy ufuncs + + @pytest.mark.parametrize( "estimator", _unsupported_estimators(), 
ids=_get_check_estimator_ids )