diff --git a/pyproject.toml b/pyproject.toml index c957920e..1ac0b580 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,6 @@ omit = [ ] [tool.mypy] -exclude = "(\\w+/)*test_\\w+\\.py$" +exclude = "(\\w+/)*test_\\w+\\.py$|old" ignore_missing_imports = true no_implicit_optional = true diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 067b13c5..d2426473 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -2,14 +2,16 @@ import io from contextlib import contextmanager -from typing import Any, Generator, Literal, Sequence, Type, Union +from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Type, Union from ._protocol import PROTOCOL -from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} +VALID_NODE_CHILD_TYPES = Optional[ + Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO] +] def check_type( @@ -168,7 +170,7 @@ def __init__( # 3. set self.children, where children are states of child nodes; do not # construct the children objects yet self.trusted = self._get_trusted(trusted, []) - self.children: dict[str, Any] = {} + self.children: dict[str, VALID_NODE_CHILD_TYPES] = {} def construct(self): """Construct the object. @@ -269,15 +271,11 @@ def get_unsafe_set(self) -> set[str]: if not check_type(get_module(child), child.__name__, self.trusted): # if the child is a type, we check its safety res.add(get_module(child) + "." + child.__name__) - elif isinstance(child, io.BytesIO): + elif isinstance(child, (io.BytesIO, str)): # We trust BytesIO objects, which are read by other - # libraries such as numpy, scipy. - continue - elif check_type( - get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES - ): - # if the child is a primitive type, we don't need to check its - # safety. + # libraries such as numpy, scipy. We trust str but have to + # be careful that anything with str is dealt with + # appropriately. continue else: raise ValueError( diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 4d302267..ce9c3969 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, Type +from typing import Any, Sequence, Type from sklearn.cluster import Birch @@ -96,7 +96,7 @@ def __init__( self, state: dict[str, Any], load_context: LoadContext, - constructor: Type[Any] | Callable[..., Any], + constructor: Type[Any], trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 75269bf0..8e31c2fc 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Iterator, Literal from zipfile import ZipFile -from ._audit import Node, get_tree +from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree from ._general import FunctionNode, JsonNode, ListNode from ._numpy import NdArrayNode from ._scipy import SparseMatrixNode @@ -168,7 +168,7 @@ def pretty_print_tree( def walk_tree( - node: Node | dict[str, Node] | list[Node], + node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES], node_name: str = "root", level: int = 0, is_last: bool = False, diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 1724a4c1..eb7c5107 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -263,7 +263,9 @@ def _unsupported_estimators(type_filter=None): ) def test_can_persist_non_fitted(estimator): """Check that non-fitted estimators can be persisted.""" - loaded = loads(dumps(estimator), trusted=True) + dumped = dumps(estimator) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(estimator.get_params(), loaded.get_params()) @@ -458,7 +460,9 @@ def split(self, X, **kwargs): ) def test_cross_validator(cv): est = CVEstimator(cv=cv).fit(None, None) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) X, y = make_classification( n_samples=N_SAMPLES, n_features=N_FEATURES, random_state=0 ) @@ -500,7 +504,9 @@ def test_numpy_object_dtype_2d_array(transpose): if transpose: est.obj_array_ = est.obj_array_.T - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) @@ -615,7 +621,8 @@ def test_identical_numpy_arrays_not_duplicated(): X = np.random.random((10, 5)) estimator = EstimatorIdenticalArrays().fit(X) dumped = dumps(estimator) - loaded = loads(dumped, trusted=True) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(estimator.__dict__, loaded.__dict__) # check number of numpy arrays stored on disk @@ -719,7 +726,9 @@ def test_for_base_case_returns_as_expected(self): bound_function = obj.bound_method transformer = FunctionTransformer(func=bound_function) - loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) @@ -736,7 +745,9 @@ def test_when_object_is_changed_after_init_works_as_expected(self): transformer = FunctionTransformer(func=bound_function) - loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) @@ -749,19 +760,23 @@ def test_works_when_given_multiple_bound_methods_attached_to_single_instance(sel func=obj.bound_method, inverse_func=obj.other_bound_method ) - loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) # check that both func and inverse_func are from the same object instance loaded_0 = loaded_transformer.func.__self__ loaded_1 = loaded_transformer.inverse_func.__self__ assert loaded_0 is loaded_1 - @pytest.mark.xfail(reason="Failing due to circular self reference") + @pytest.mark.xfail(reason="Failing due to circular self reference", strict=True) def test_scipy_stats(self, tmp_path): from scipy import stats estimator = FunctionTransformer(func=stats.zipf) - loads(dumps(estimator), trusted=True) + dumped = dumps(estimator) + untrusted_types = get_untrusted_types(data=dumped) + loads(dumped, trusted=untrusted_types) class CustomEstimator(BaseEstimator): @@ -862,7 +877,9 @@ def test_dump_and_load_with_file_wrapper(tmp_path): ) def test_when_given_object_referenced_twice_loads_as_one_object(obj): an_object = {"obj_1": obj, "obj_2": obj} - persisted_object = loads(dumps(an_object), trusted=True) + dumped = dumps(an_object) + untrusted_types = get_untrusted_types(data=dumped) + persisted_object = loads(dumped, trusted=untrusted_types) assert persisted_object["obj_1"] is persisted_object["obj_2"] @@ -876,7 +893,9 @@ def fit(self, X, y, **fit_params): def test_estimator_with_bytes(): est = EstimatorWithBytes().fit(None, None) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) @@ -934,13 +953,17 @@ def test_persist_operator(op): _, func = op # unfitted est = FunctionTransformer(func) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) # fitted X, y = get_input(est) est.fit(X, y) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) # Technically, we don't need to call transform. However, if this is skipped, @@ -973,7 +996,8 @@ def test_persist_function(func): estimator.fit(X, y) dumped = dumps(estimator) - loaded = loads(dumped, trusted=True) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) # check that loaded estimator is identical assert_params_equal(estimator.__dict__, loaded.__dict__)