diff --git a/docs/changes.rst b/docs/changes.rst
index 9c1895b5..24442b83 100644
--- a/docs/changes.rst
+++ b/docs/changes.rst
@@ -11,6 +11,9 @@ skops Changelog
 
 v0.13
 -----
+- `Card` now requires a new parameter, `allow_pickle`, to call `get_model` with
+  models that are not `.skops` files. This change is to mitigate security risks
+  associated with pickles. :pr:`485` by `Io_no`_.
 
 v0.12
 -----
diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py
index cb98cd21..b064c54b 100644
--- a/skops/card/_model_card.py
+++ b/skops/card/_model_card.py
@@ -227,7 +227,9 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({nrows}x{ncols})"
 
 
-def _load_model(model: Any, trusted=False) -> Any:
+def _load_model(
+    model: Any, trusted: Optional[Sequence[str]] = None, allow_pickle: bool = False
+) -> Any:
     """Return a model instance.
 
     Loads the model if provided a file path, if already a model instance return
@@ -238,10 +240,14 @@ def _load_model(model: Any, trusted=False) -> Any:
     model : pathlib.Path, str, or sklearn estimator
         Path/str or the actual model instance. if a Path or str, loads the model.
 
-    trusted : bool, default=False
+    trusted : list of str, default=None
         Passed to :func:`skops.io.load` if the model is a file path and it's
         a `skops` file.
 
+    allow_pickle : bool, default=False
+        If `True`, allows loading models using `joblib.load`. This may lead to
+        security issues if the model file is not trustworthy.
+
     Returns
     -------
     model : object
@@ -255,13 +261,28 @@ def _load_model(model: Any, trusted=False) -> Any:
     if not model_path.exists():
         raise FileNotFoundError(f"File is not present: {model_path}")
 
+    if trusted and allow_pickle:
+        raise ValueError(
+            "`allow_pickle` cannot be `True` if `trusted` is not empty. "
+            "Pickles cannot be trusted or checked for security issues."
+        )
+
+    msg = ""
     try:
         if zipfile.is_zipfile(model_path):
             model = load(model_path, trusted=trusted)
-        else:
+        elif allow_pickle:
             model = joblib.load(model_path)
+        else:
+            msg = (
+                "Model file is not a skops file, and allow_pickle is set to False. "
+                "Please set allow_pickle=True to load the model. "
+                "This may lead to security issues if the model file is not trustworthy."
+            )
+            raise RuntimeError(msg)
     except Exception as ex:
-        msg = f'An "{type(ex).__name__}" occurred during model loading.'
+        if not msg:
+            msg = f'"{type(ex).__name__}" occurred during model loading.'
         raise RuntimeError(msg) from ex
 
     return model
@@ -310,10 +331,14 @@ class Card:
         not work, e.g. :meth:`Card.add_metrics`, since it's not clear where to
         put the metrics when there is no template or a custom template.
 
-    trusted: bool, default=False
+    trusted: list of str, default=None
         Passed to :func:`skops.io.load` if the model is a file path and it's
         a `skops` file.
 
+    allow_pickle: bool, default=False
+        If `True`, allows loading models using `joblib.load`. This may lead to
+        security issues if the model file is not trustworthy.
+
     Attributes
     ----------
     model: estimator object
@@ -379,11 +404,13 @@ def __init__(
         model_diagram: bool | Literal["auto"] | str = "auto",
         template: Literal["skops"] | dict[str, str] | None = "skops",
         trusted: Optional[List[str]] = None,
+        allow_pickle: bool = False,
     ) -> None:
         self.model = model
         self.model_format = model_format
         self.template = template
         self.trusted = trusted
+        self.allow_pickle = allow_pickle
 
         self._data: dict[str, Section] = {}
         self._metrics: dict[str, str | float | int] = {}
@@ -465,7 +492,7 @@ def get_model(self) -> Any:
 
     @cached_property
     def _model(self):
-        model = _load_model(self.model, self.trusted)
+        model = _load_model(self.model, self.trusted, self.allow_pickle)
         return model
 
     def add(self, folded: bool = False, **kwargs: str) -> Self:
diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py
index 0eea01ba..29f68b7b 100644
--- a/skops/card/tests/test_card.py
+++ b/skops/card/tests/test_card.py
@@ -82,11 +82,17 @@ def test_load_model(suffix):
     _, save_file = save_model_to_file(model0, suffix)
     if suffix == ".skops":
         untrusted_types = get_untrusted_types(file=save_file)
+        allow_pickle = False
     else:
         untrusted_types = None
-    loaded_model_str = _load_model(save_file, trusted=untrusted_types)
+        allow_pickle = True
+    loaded_model_str = _load_model(
+        save_file, trusted=untrusted_types, allow_pickle=allow_pickle
+    )
     save_file_path = Path(save_file)
-    loaded_model_path = _load_model(save_file_path, trusted=untrusted_types)
+    loaded_model_path = _load_model(
+        save_file_path, trusted=untrusted_types, allow_pickle=allow_pickle
+    )
     loaded_model_instance = _load_model(model0, trusted=untrusted_types)
 
     assert loaded_model_str.param_1 == 10
@@ -94,6 +100,31 @@ def test_load_model(suffix):
     assert loaded_model_instance.param_1 == 10
 
 
+@pytest.mark.parametrize("suffix", [".pkl", ".pickle"])
+def test_load_model_exception_allow_pickle(suffix):
+    model0 = MyRegressor(param_1=10)
+    _, save_file = save_model_to_file(model0, suffix)
+
+    with pytest.raises(
+        RuntimeError,
+        match=(
+            "Model file is not a skops file, and allow_pickle is set to False. "
+            "Please set allow_pickle=True to load the model. "
+            "This may lead to security issues if the model file is not trustworthy."
+        ),
+    ):
+        _load_model(save_file, trusted=None, allow_pickle=False)
+
+    with pytest.raises(
+        ValueError,
+        match=(
+            "`allow_pickle` cannot be `True` if `trusted` is not empty. "
+            "Pickles cannot be trusted or checked for security issues."
+        ),
+    ):
+        _load_model(save_file, trusted=[""], allow_pickle=True)
+
+
 @pytest.fixture
 def model_card(model_diagram=True):
     model = fit_model()
@@ -166,7 +197,9 @@ def test_model_caching(skops_model_card, iris_skops_file, destination_path):
     new_model = MyClassifier(param_1=10)
 
     # mock _load_model, it still loads the model but we can track call count
-    mock_load_model = mock.Mock(side_effect=load)
+    mock_load_model = mock.Mock(
+        side_effect=lambda path, trusted, _: load(path, trusted=trusted)
+    )
     card = Card(iris_skops_file, trusted=[MyClassifier])
     with mock.patch("skops.card._model_card._load_model", mock_load_model):
         model1 = card.get_model()
@@ -1133,7 +1166,7 @@ def path_to_card(self, path, suffix):
         if suffix == ".skops":
             card = Card(model=path, trusted=get_untrusted_types(file=path))
         else:
-            card = Card(model=path)
+            card = Card(model=path, allow_pickle=True)
         return card
 
     @pytest.mark.parametrize("meth", [repr, str])
@@ -1178,7 +1211,7 @@ def test_load_model_exception(self, meth, suffix):
         os.close(file_handle)
 
         with pytest.raises(Exception, match="occurred during model loading."):
-            card = Card(file_name)
+            card = Card(file_name, allow_pickle=True)
             meth(card)
 
     @pytest.mark.parametrize("meth", [repr, str])