Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ skops Changelog

v0.13
-----
- `Card` now requires a new parameter, `allow_pickle`, to call `get_model` with
models that are not `.skops` files. This change is to mitigate security risks
associated with pickles. :pr:`485` by `Io_no`_.

v0.12
-----
Expand Down
39 changes: 33 additions & 6 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({nrows}x{ncols})"


def _load_model(model: Any, trusted=False) -> Any:
def _load_model(
model: Any, trusted: Optional[Sequence[str]] = None, allow_pickle: bool = False
) -> Any:
"""Return a model instance.

Loads the model if provided a file path, if already a model instance return
Expand All @@ -238,10 +240,14 @@ def _load_model(model: Any, trusted=False) -> Any:
model : pathlib.Path, str, or sklearn estimator
Path/str or the actual model instance. if a Path or str, loads the model.

trusted : bool, default=False
trusted: list of str, default=None
Passed to :func:`skops.io.load` if the model is a file path and it's
a `skops` file.

allow_pickle : bool, default=False
If `True`, allows loading models using `joblib.load`. This may lead to
security issues if the model file is not trustworthy.

Returns
-------
model : object
Expand All @@ -255,13 +261,28 @@ def _load_model(model: Any, trusted=False) -> Any:
if not model_path.exists():
raise FileNotFoundError(f"File is not present: {model_path}")

if trusted and allow_pickle:
raise ValueError(
"`allow_pickle` cannot be `True` if `trusted` is not empty. "
"Pickles cannot be trusted or checked for security issues."
)

msg = ""
try:
if zipfile.is_zipfile(model_path):
model = load(model_path, trusted=trusted)
else:
elif allow_pickle:
model = joblib.load(model_path)
else:
msg = (
"Model file is not a skops file, and allow_pickle is set to False. "
"Please set allow_pickle=True to load the model."
"This may lead to security issues if the model file is not trustworthy."
)
raise RuntimeError(msg)
except Exception as ex:
msg = f'An "{type(ex).__name__}" occurred during model loading.'
if not msg:
msg = f'"{type(ex).__name__}" occurred during model loading.'
raise RuntimeError(msg) from ex

return model
Expand Down Expand Up @@ -310,10 +331,14 @@ class Card:
not work, e.g. :meth:`Card.add_metrics`, since it's not clear where to
put the metrics when there is no template or a custom template.

trusted: bool, default=False
trusted: list of str, default=None
Passed to :func:`skops.io.load` if the model is a file path and it's
a `skops` file.

allow_pickle: bool, default=False
If `True`, allows loading models using `joblib.load`. This may lead to
security issues if the model file is not trustworthy.

Attributes
----------
model: estimator object
Expand Down Expand Up @@ -379,11 +404,13 @@ def __init__(
model_diagram: bool | Literal["auto"] | str = "auto",
template: Literal["skops"] | dict[str, str] | None = "skops",
trusted: Optional[List[str]] = None,
allow_pickle: bool = False,
) -> None:
self.model = model
self.model_format = model_format
self.template = template
self.trusted = trusted
self.allow_pickle = allow_pickle

self._data: dict[str, Section] = {}
self._metrics: dict[str, str | float | int] = {}
Expand Down Expand Up @@ -465,7 +492,7 @@ def get_model(self) -> Any:

@cached_property
def _model(self):
model = _load_model(self.model, self.trusted)
model = _load_model(self.model, self.trusted, self.allow_pickle)
return model

def add(self, folded: bool = False, **kwargs: str) -> Self:
Expand Down
43 changes: 38 additions & 5 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,49 @@ def test_load_model(suffix):
_, save_file = save_model_to_file(model0, suffix)
if suffix == ".skops":
untrusted_types = get_untrusted_types(file=save_file)
allow_pickle = False
else:
untrusted_types = None
loaded_model_str = _load_model(save_file, trusted=untrusted_types)
allow_pickle = True
loaded_model_str = _load_model(
save_file, trusted=untrusted_types, allow_pickle=allow_pickle
)
save_file_path = Path(save_file)
loaded_model_path = _load_model(save_file_path, trusted=untrusted_types)
loaded_model_path = _load_model(
save_file_path, trusted=untrusted_types, allow_pickle=allow_pickle
)
loaded_model_instance = _load_model(model0, trusted=untrusted_types)

assert loaded_model_str.param_1 == 10
assert loaded_model_path.param_1 == 10
assert loaded_model_instance.param_1 == 10


@pytest.mark.parametrize("suffix", [".pkl", ".pickle"])
def test_load_model_exception_allow_pickle(suffix):
model0 = MyRegressor(param_1=10)
_, save_file = save_model_to_file(model0, suffix)

with pytest.raises(
RuntimeError,
match=(
"Model file is not a skops file, and allow_pickle is set to False. "
"Please set allow_pickle=True to load the model."
"This may lead to security issues if the model file is not trustworthy."
),
):
_load_model(save_file, trusted=None, allow_pickle=False)

with pytest.raises(
ValueError,
match=(
"`allow_pickle` cannot be `True` if `trusted` is not empty. "
"Pickles cannot be trusted or checked for security issues."
),
):
_load_model(save_file, trusted=[""], allow_pickle=True)


@pytest.fixture
def model_card(model_diagram=True):
model = fit_model()
Expand Down Expand Up @@ -166,7 +197,9 @@ def test_model_caching(skops_model_card, iris_skops_file, destination_path):

new_model = MyClassifier(param_1=10)
# mock _load_model, it still loads the model but we can track call count
mock_load_model = mock.Mock(side_effect=load)
mock_load_model = mock.Mock(
side_effect=lambda path, trusted, _: load(path, trusted=trusted)
)
card = Card(iris_skops_file, trusted=[MyClassifier])
with mock.patch("skops.card._model_card._load_model", mock_load_model):
model1 = card.get_model()
Expand Down Expand Up @@ -1133,7 +1166,7 @@ def path_to_card(self, path, suffix):
if suffix == ".skops":
card = Card(model=path, trusted=get_untrusted_types(file=path))
else:
card = Card(model=path)
card = Card(model=path, allow_pickle=True)
return card

@pytest.mark.parametrize("meth", [repr, str])
Expand Down Expand Up @@ -1178,7 +1211,7 @@ def test_load_model_exception(self, meth, suffix):
os.close(file_handle)

with pytest.raises(Exception, match="occurred during model loading."):
card = Card(file_name)
card = Card(file_name, allow_pickle=True)
meth(card)

@pytest.mark.parametrize("meth", [repr, str])
Expand Down
Loading