Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ v0.4
- :func:`.io.dump` and :func:`.io.load` now work with file like objects,
which means you can use them with the ``with open(...) as f: dump(obj, f)``
pattern, like you'd do with ``pickle``. :pr:`234` by `Benjamin Bossan`_.
- All ``scikit-learn`` estimators are trusted by default.
:pr:`237` by :user:`Edoardo Abati <EdAbati>`.

v0.3
----
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

from ._audit import Node, get_tree
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._trusted_types import PRIMITIVE_TYPE_NAMES, SKLEARN_ESTIMATOR_TYPE_NAMES
from ._utils import (
LoadContext,
SaveContext,
Expand Down Expand Up @@ -383,7 +383,7 @@ def __init__(

self.children = {"attrs": attrs}
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, [])
self.trusted = self._get_trusted(trusted, default=SKLEARN_ESTIMATOR_TYPE_NAMES)

def _construct(self):
cls = gettype(self.module_name, self.class_name)
Expand Down
10 changes: 10 additions & 0 deletions skops/io/_trusted_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from sklearn.utils import all_estimators

from ._utils import get_type_name

# Built-in primitive types handled by this module.
PRIMITIVES_TYPES = [int, float, str, bool]

# Qualified names of the primitive types above, e.g. "builtins.int".
PRIMITIVE_TYPE_NAMES = [
    f"builtins.{primitive.__name__}" for primitive in PRIMITIVES_TYPES
]

# Type names (as produced by ``get_type_name``) of every estimator class
# yielded by ``all_estimators()``, restricted to names that start with
# "sklearn.". The inner generator resolves each class name exactly once so
# the same value is used for both the filter and the resulting entry.
SKLEARN_ESTIMATOR_TYPE_NAMES = [
    type_name
    for type_name in (
        get_type_name(estimator_class) for _, estimator_class in all_estimators()
    )
    if type_name.startswith("sklearn.")
]
10 changes: 1 addition & 9 deletions skops/io/tests/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,4 @@ def test_complex_pipeline_untrusted_set():

untrusted = get_untrusted_types(data=dumps(clf))
type_names = [x.split(".")[-1] for x in untrusted]
assert type_names == [
"sqrt",
"square",
"LogisticRegression",
"FeatureUnion",
"Pipeline",
"StandardScaler",
"FunctionTransformer",
]
assert type_names == ["sqrt", "square"]
5 changes: 4 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from skops.io import dump, dumps, get_untrusted_types, load, loads
from skops.io._audit import NODE_TYPE_MAPPING, get_tree
from skops.io._sklearn import UNSUPPORTED_TYPES
from skops.io._trusted_types import SKLEARN_ESTIMATOR_TYPE_NAMES
from skops.io._utils import LoadContext, SaveContext, _get_state, get_state
from skops.io.exceptions import UnsupportedTypeException

Expand Down Expand Up @@ -468,7 +469,7 @@ def get_input(estimator):
@pytest.mark.parametrize(
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
)
def test_can_persist_fitted(estimator, request):
Comment thread
adrinjalali marked this conversation as resolved.
def test_can_persist_fitted(estimator):
"""Check that fitted estimators can be persisted and return the right results."""
set_random_state(estimator, random_state=0)

Expand All @@ -491,6 +492,8 @@ def test_can_persist_fitted(estimator, request):
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.__dict__, loaded.__dict__)

assert not any(type_ in SKLEARN_ESTIMATOR_TYPE_NAMES for type_ in untrusted_types)

for method in [
"predict",
"predict_proba",
Expand Down