Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ codecov:
branch: main
require_ci_to_pass: true
notify:
after_n_builds: 12
after_n_builds: 21
wait_for_ci: true
ignore:
- "skops/_min_dependencies.py" # This file is not tested, and won't be.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 15
timeout-minutes: 30

steps:
# The following two steps are workarounds to retrieve the "real" commit
Expand Down
6 changes: 6 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ skops Changelog

v0.12
-----
- `huggingface_hub` dependency is now optional. :pr:`462` by `Adrin Jalali`_.
- An object's `__reduce__` is used when its output is of the form
`(type, (constructor_args,))`, where `type` is the same as `type(obj)`.
:pr:`467` by `Adrin Jalali`_.
- `MethodNode` and `OperatorFuncNode` now have a hardened audit, removing certain
security vulnerabilities. :pr:`482` by `Adrin Jalali`_.

v0.11
-----
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ filterwarnings = [
"ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
# BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out
"ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning",
# This comes from matplotlib somehow
"ignore:'mode' parameter is deprecated and will be removed in Pillow 13:DeprecationWarning",
]
addopts = "--cov=skops --cov-report=term-missing --doctest-modules"

Expand Down
7 changes: 4 additions & 3 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import io
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Type, Union

from ._protocol import PROTOCOL
from ._utils import LoadContext, get_module, get_type_paths
Expand Down Expand Up @@ -39,7 +39,7 @@ def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool
return module_name + "." + type_name in trusted


def audit_tree(tree: Node) -> None:
def audit_tree(tree: Node, trusted: Iterable[str] | None) -> None:
"""Audit a tree of nodes.

A tree is safe if it only contains trusted types.
Expand All @@ -54,7 +54,8 @@ def audit_tree(tree: Node) -> None:
UntrustedTypesFoundException
If the tree contains an untrusted type.
"""
unsafe = tree.get_unsafe_set()
trusted = trusted or set()
unsafe = tree.get_unsafe_set() - set(trusted)
if unsafe:
raise UntrustedTypesFoundException(unsafe)

Expand Down
34 changes: 30 additions & 4 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,15 @@ def method_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# dependent on a specific instance of an object.
# It stores the state of the object the method is bound to,
# and prepares both to be persisted.
owner = obj.__self__
func_name = obj.__func__.__name__
res = {
"__class__": obj.__class__.__name__,
"__class__": owner.__class__.__name__,
"__module__": get_module(obj),
"__loader__": "MethodNode",
"content": {
"func": obj.__func__.__name__,
"obj": get_state(obj.__self__, save_context),
"func": func_name,
"obj": get_state(owner, save_context),
},
}
return res
Expand All @@ -529,13 +531,32 @@ def __init__(
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
obj = get_tree(state["content"]["obj"], load_context, trusted=trusted)
if self.module_name != obj.module_name or self.class_name != obj.class_name:
raise ValueError(
f"Expected object of type {self.module_name}.{self.class_name}, got"
f" {obj.module_name}.{obj.class_name}. This is probably due to a"
" corrupted or a malicious file."
)
self.children = {
"obj": get_tree(state["content"]["obj"], load_context, trusted=trusted),
"obj": obj,
"func": state["content"]["func"],
}
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, [])

def get_unsafe_set(self) -> set[str]:
    """Collect untrusted names, including the bound method itself.

    The set returned by the parent implementation is extended with one
    extra entry of the form ``<module>.<class>.<func>``, built from the
    ``obj`` child node and the stored function name. No trust check is
    applied to that entry here; it is always added.

    Returns
    -------
    set of str
        The parent's unsafe set plus the dotted path of the method.
    """
    unsafe = super().get_unsafe_set()
    owner = self.children["obj"]
    # Plain ``+`` concatenation is kept on purpose: a non-string component
    # raises TypeError instead of being silently formatted.
    owner_path = owner.module_name + "." + owner.class_name  # type: ignore
    method_path = owner_path + "." + self.children["func"]
    unsafe.add(method_path)
    return unsafe

def _construct(self):
loaded_obj = self.children["obj"].construct()
method = getattr(loaded_obj, self.children["func"])
Expand Down Expand Up @@ -658,6 +679,11 @@ def __init__(
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
if self.module_name != "operator":
raise ValueError(
f"Expected module 'operator', got {self.module_name}. This is probably"
" due to a corrupted or a malicious file."
)
self.trusted = self._get_trusted(trusted, [])
self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted)

Expand Down
4 changes: 2 additions & 2 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def load(file: str | Path, trusted: Optional[Sequence[str]] = None) -> Any:
schema = json.loads(input_zip.read("schema.json"))
load_context = LoadContext(src=input_zip, protocol=schema["protocol"])
tree = get_tree(schema, load_context, trusted=trusted)
audit_tree(tree)
audit_tree(tree, trusted=trusted)
instance = tree.construct()

return instance
Expand Down Expand Up @@ -188,7 +188,7 @@ def loads(data: bytes, trusted: Optional[Sequence[str]] = None) -> Any:
schema = json.loads(zip_file.read("schema.json"))
load_context = LoadContext(src=zip_file, protocol=schema["protocol"])
tree = get_tree(schema, load_context, trusted=trusted)
audit_tree(tree)
audit_tree(tree, trusted=trusted)
instance = tree.construct()

return instance
Expand Down
43 changes: 38 additions & 5 deletions skops/io/tests/test_audit.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import io
import json
import operator
import re
from contextlib import suppress
from zipfile import ZipFile

import pytest
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer

from skops.io import dumps, get_untrusted_types
from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr
from skops.io._general import DictNode, JsonNode, ObjectNode, dict_get_state
from skops.io._general import (
DictNode,
JsonNode,
MethodNode,
ObjectNode,
OperatorFuncNode,
dict_get_state,
method_get_state,
operator_func_get_state,
)
from skops.io._utils import LoadContext, SaveContext, get_state, gettype


Expand Down Expand Up @@ -46,26 +57,26 @@ def test_audit_tree_untrusted():
"Untrusted types found in the file: ['test_audit.CustomType']."
),
):
audit_tree(node)
audit_tree(node, None)

# there shouldn't be an error with trusted=everything
node = DictNode(state, LoadContext(None, -1), trusted=["test_audit.CustomType"])
audit_tree(node)
audit_tree(node, None)

untrusted_list = get_untrusted_types(data=dumps(var))
assert untrusted_list == ["test_audit.CustomType"]

# passing the type would fix it.
node = DictNode(state, LoadContext(None, -1), trusted=untrusted_list)
audit_tree(node)
audit_tree(node, None)


def test_audit_tree_defaults():
# test that the default types are trusted
var = {"a": 1, 2: "b"}
state = dict_get_state(var, SaveContext(None, 0, {}))
node = DictNode(state, LoadContext(None, -1), trusted=None)
audit_tree(node)
audit_tree(node, None)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -170,3 +181,25 @@ def test_format_json_node(inp, expected):
state = get_state(inp, SaveContext(None))
node = JsonNode(state, LoadContext(None, -1))
assert node.format() == expected


def test_method_node_invalid_state():
    """MethodNode must raise ValueError when its state is inconsistent.

    The ``__class__`` recorded for the nested ``obj`` entry is overwritten
    so that it no longer agrees with the outer state.
    """
    bound_method = FunctionTransformer().fit
    save_ctx = SaveContext(None, 0, {})
    tampered = method_get_state(bound_method, save_ctx)
    # Corrupt only the class name stored inside the content.
    tampered["content"]["obj"]["__class__"] = "foo"
    ctx = LoadContext(None, -1)

    with pytest.raises(ValueError, match="Expected object of type"):
        MethodNode(tampered, ctx, trusted=None)


def test_operator_func_node_invalid_state():
    """OperatorFuncNode must raise ValueError for a non-``operator`` module.

    The ``__module__`` entry of an otherwise valid state is replaced with
    ``"foo"`` before the node is built.
    """
    caller = operator.methodcaller("fit")
    save_ctx = SaveContext(None, 0, {})
    tampered = operator_func_get_state(caller, save_ctx)
    # Pretend the callable comes from somewhere other than ``operator``.
    tampered["__module__"] = "foo"
    ctx = LoadContext(None, -1)

    with pytest.raises(ValueError, match="Expected module 'operator'"):
        OperatorFuncNode(tampered, ctx, trusted=None)
Loading