diff --git a/docs/changes.rst b/docs/changes.rst index 034074a4..bea77aab 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -14,6 +14,8 @@ v0.4 - :func:`.io.dump` and :func:`.io.load` now work with file like objects, which means you can use them with the ``with open(...) as f: dump(obj, f)`` pattern, like you'd do with ``pickle``. :pr:`234` by `Benjamin Bossan`_. +- All `scikit-learn` estimators are trusted by default. + :pr:`237` by :user:`Edoardo Abati `. v0.3 ---- diff --git a/skops/io/_general.py b/skops/io/_general.py index 7f27fbcb..126bcfaf 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -8,7 +8,7 @@ import numpy as np from ._audit import Node, get_tree -from ._trusted_types import PRIMITIVE_TYPE_NAMES +from ._trusted_types import PRIMITIVE_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES from ._utils import ( LoadContext, SaveContext, @@ -383,7 +383,7 @@ def __init__( self.children = {"attrs": attrs} # TODO: what do we trust? - self.trusted = self._get_trusted(trusted, []) + self.trusted = self._get_trusted(trusted, default=SKLEARN_ESTIMATOR_TYPE_NAMES) def _construct(self): cls = gettype(self.module_name, self.class_name) diff --git a/skops/io/_trusted_types.py b/skops/io/_trusted_types.py index e3c38ffd..1ef1b826 100644 --- a/skops/io/_trusted_types.py +++ b/skops/io/_trusted_types.py @@ -1,3 +1,13 @@ +from sklearn.utils import all_estimators + +from ._utils import get_type_name + PRIMITIVES_TYPES = [int, float, str, bool] PRIMITIVE_TYPE_NAMES = ["builtins." + t.__name__ for t in PRIMITIVES_TYPES] + +SKLEARN_ESTIMATOR_TYPE_NAMES = [ + get_type_name(estimator_class) + for _, estimator_class in all_estimators() + if get_type_name(estimator_class).startswith("sklearn.") +] diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index a1ae0188..71914b4d 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -162,12 +162,4 @@ def test_complex_pipeline_untrusted_set(): untrusted = get_untrusted_types(data=dumps(clf)) type_names = [x.split(".")[-1] for x in untrusted] - assert type_names == [ - "sqrt", - "square", - "LogisticRegression", - "FeatureUnion", - "Pipeline", - "StandardScaler", - "FunctionTransformer", - ] + assert type_names == ["sqrt", "square"] diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index c38871c6..012ab7f7 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -56,6 +56,7 @@ from skops.io import dump, dumps, get_untrusted_types, load, loads from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES +from skops.io._trusted_types import SKLEARN_ESTIMATOR_TYPE_NAMES from skops.io._utils import LoadContext, SaveContext, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException @@ -468,7 +469,7 @@ def get_input(estimator): @pytest.mark.parametrize( "estimator", _tested_estimators(), ids=_get_check_estimator_ids ) -def test_can_persist_fitted(estimator, request): +def test_can_persist_fitted(estimator): """Check that fitted estimators can be persisted and return the right results.""" set_random_state(estimator, random_state=0) @@ -491,6 +492,8 @@ def test_can_persist_fitted(estimator, request): loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(estimator.__dict__, loaded.__dict__) + assert not any(type_ in SKLEARN_ESTIMATOR_TYPE_NAMES for type_ in untrusted_types) + for method in [ "predict", "predict_proba",