From a30f71ba667ec92996ce4d9f4c6523b4435d81de Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 24 Jul 2025 12:42:20 +0200 Subject: [PATCH 1/9] ENH harden Method and Operator node audits --- skops/io/_general.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 07cdfd53..280ac767 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -529,13 +529,37 @@ def __init__( trusted: Optional[Sequence[str]] = None, ) -> None: super().__init__(state, load_context, trusted) + obj = get_tree(state["content"]["obj"], load_context, trusted=trusted) + if self.class_name != obj.class_name or self.module_name != obj.module_name: + raise ValueError( + f"Expected object of type {self.class_name}.{self.module_name}, got" + f" {obj.class_name}.{obj.module_name}. This is probably due to a" + " corrupted or a malicious file." + ) self.children = { - "obj": get_tree(state["content"]["obj"], load_context, trusted=trusted), + "obj": obj, "func": state["content"]["func"], } # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) + def get_unsafe_set(self) -> set[str]: + res = super().get_unsafe_set() + obj_node = self.children["obj"] + if not hasattr(obj_node, "module_name") or not hasattr(obj_node, "class_name"): + raise ValueError( + "MethodNode must have an object node as child. This is probably due to" + " a corrupted or a malicious file." + ) + res.add( + obj_node.module_name # type: ignore + + "." + + obj_node.class_name # type: ignore + + "." + + self.children["func"] + ) + return res + def _construct(self): loaded_obj = self.children["obj"].construct() method = getattr(loaded_obj, self.children["func"]) @@ -658,6 +682,11 @@ def __init__( trusted: Optional[Sequence[str]] = None, ) -> None: super().__init__(state, load_context, trusted) + if self.module_name != "operator": + raise ValueError( + f"Expected module 'operator', got {self.module_name}. This is probably" + " due to a corrupted or a malicious file." + ) self.trusted = self._get_trusted(trusted, []) self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted) From a65101230e4cd481083f4c6db1007aaec643a21a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 24 Jul 2025 13:51:41 +0200 Subject: [PATCH 2/9] ... --- pyproject.toml | 1 + skops/io/_audit.py | 2 +- skops/io/_general.py | 14 ++++++++------ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3b69877..fac38d71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ filterwarnings = [ "ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning", # BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out "ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning", + "ignore:'mode' parameter is deprecated and will be removed in Pillow 13:DeprecationWarning", ] addopts = "--cov=skops --cov-report=term-missing --doctest-modules" diff --git a/skops/io/_audit.py b/skops/io/_audit.py index ac936627..8aa1a1f4 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -266,7 +266,7 @@ def get_unsafe_set(self) -> set[str]: " for us to fix the issue." ) - return res + return res - set(self.trusted) def format(self) -> str: """Representation of the node's content.""" diff --git a/skops/io/_general.py b/skops/io/_general.py index 280ac767..2f72ff82 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -509,13 +509,15 @@ def method_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # dependent on a specific instance of an object. # It stores the state of the object the method is bound to, # and prepares both to be persisted. + owner = obj.__self__ + func_name = obj.__func__.__name__ res = { - "__class__": obj.__class__.__name__, + "__class__": owner.__class__.__name__, "__module__": get_module(obj), "__loader__": "MethodNode", "content": { - "func": obj.__func__.__name__, - "obj": get_state(obj.__self__, save_context), + "func": func_name, + "obj": get_state(owner, save_context), }, } return res @@ -530,10 +532,10 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) obj = get_tree(state["content"]["obj"], load_context, trusted=trusted) - if self.class_name != obj.class_name or self.module_name != obj.module_name: + if self.module_name != obj.module_name or self.class_name != obj.class_name: raise ValueError( - f"Expected object of type {self.class_name}.{self.module_name}, got" - f" {obj.class_name}.{obj.module_name}. This is probably due to a" + f"Expected object of type {self.module_name}.{self.class_name}, got" + f" {obj.module_name}.{obj.class_name}. This is probably due to a" " corrupted or a malicious file." ) self.children = { From a926d9708226966394ca3beacdfc9ac02b583b13 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 24 Jul 2025 14:03:22 +0200 Subject: [PATCH 3/9] ... --- skops/io/_audit.py | 9 +++++---- skops/io/_persist.py | 4 ++-- skops/io/tests/test_audit.py | 8 ++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 8aa1a1f4..d9e7a03b 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -2,7 +2,7 @@ import io from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Type, Union from ._protocol import PROTOCOL from ._utils import LoadContext, get_module, get_type_paths @@ -39,7 +39,7 @@ def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool return module_name + "." + type_name in trusted -def audit_tree(tree: Node) -> None: +def audit_tree(tree: Node, trusted: Iterable[str] | None) -> None: """Audit a tree of nodes. A tree is safe if it only contains trusted types. @@ -54,7 +54,8 @@ def audit_tree(tree: Node) -> None: UntrustedTypesFoundException If the tree contains an untrusted type. """ - unsafe = tree.get_unsafe_set() + trusted = trusted or set() + unsafe = tree.get_unsafe_set() - set(trusted) if unsafe: raise UntrustedTypesFoundException(unsafe) @@ -266,7 +267,7 @@ def get_unsafe_set(self) -> set[str]: " for us to fix the issue." ) - return res - set(self.trusted) + return res def format(self) -> str: """Representation of the node's content.""" diff --git a/skops/io/_persist.py b/skops/io/_persist.py index aaed469c..6dd00e2f 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -148,7 +148,7 @@ def load(file: str | Path, trusted: Optional[Sequence[str]] = None) -> Any: schema = json.loads(input_zip.read("schema.json")) load_context = LoadContext(src=input_zip, protocol=schema["protocol"]) tree = get_tree(schema, load_context, trusted=trusted) - audit_tree(tree) + audit_tree(tree, trusted=trusted) instance = tree.construct() return instance @@ -188,7 +188,7 @@ def loads(data: bytes, trusted: Optional[Sequence[str]] = None) -> Any: schema = json.loads(zip_file.read("schema.json")) load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) tree = get_tree(schema, load_context, trusted=trusted) - audit_tree(tree) + audit_tree(tree, trusted=trusted) instance = tree.construct() return instance diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 0e6514c3..903302bd 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -46,18 +46,18 @@ def test_audit_tree_untrusted(): "Untrusted types found in the file: ['test_audit.CustomType']." ), ): - audit_tree(node) + audit_tree(node, None) # there shouldn't be an error with trusted=everything node = DictNode(state, LoadContext(None, -1), trusted=["test_audit.CustomType"]) - audit_tree(node) + audit_tree(node, None) untrusted_list = get_untrusted_types(data=dumps(var)) assert untrusted_list == ["test_audit.CustomType"] # passing the type would fix it. node = DictNode(state, LoadContext(None, -1), trusted=untrusted_list) - audit_tree(node) + audit_tree(node, None) def test_audit_tree_defaults(): @@ -65,7 +65,7 @@ def test_audit_tree_defaults(): var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) node = DictNode(state, LoadContext(None, -1), trusted=None) - audit_tree(node) + audit_tree(node, None) @pytest.mark.parametrize( From 4ba94e7635ed68e93fc723131190a145ea414894 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 24 Jul 2025 14:07:08 +0200 Subject: [PATCH 4/9] trigger CI From 91ee39807d78f7656be5221b324192c930e10c4c Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 24 Jul 2025 15:45:25 +0200 Subject: [PATCH 5/9] add tests --- pyproject.toml | 1 + skops/io/_general.py | 5 ----- skops/io/tests/test_audit.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fac38d71..f40fae27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ filterwarnings = [ "ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning", # BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out "ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning", + # This comes from matplotlib somehow "ignore:'mode' parameter is deprecated and will be removed in Pillow 13:DeprecationWarning", ] addopts = "--cov=skops --cov-report=term-missing --doctest-modules" diff --git a/skops/io/_general.py b/skops/io/_general.py index 2f72ff82..a8715a71 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -548,11 +548,6 @@ def __init__( def get_unsafe_set(self) -> set[str]: res = super().get_unsafe_set() obj_node = self.children["obj"] - if not hasattr(obj_node, "module_name") or not hasattr(obj_node, "class_name"): - raise ValueError( - "MethodNode must have an object node as child. This is probably due to" - " a corrupted or a malicious file." - ) res.add( obj_node.module_name # type: ignore + "." diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 903302bd..291a409c 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -1,15 +1,26 @@ import io import json +import operator import re from contextlib import suppress from zipfile import ZipFile import pytest from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import FunctionTransformer from skops.io import dumps, get_untrusted_types from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr -from skops.io._general import DictNode, JsonNode, ObjectNode, dict_get_state +from skops.io._general import ( + DictNode, + JsonNode, + MethodNode, + ObjectNode, + OperatorFuncNode, + dict_get_state, + method_get_state, + operator_func_get_state, +) from skops.io._utils import LoadContext, SaveContext, get_state, gettype @@ -170,3 +181,25 @@ def test_format_json_node(inp, expected): state = get_state(inp, SaveContext(None)) node = JsonNode(state, LoadContext(None, -1)) assert node.format() == expected + + +def test_method_node_invalid_state(): + # Test that MethodNode raises a ValueError if the state is invalid. + # The __class__ and __module__ should match what's inside the content. + var = FunctionTransformer().fit + state = method_get_state(var, SaveContext(None, 0, {})) + state["content"]["obj"]["__class__"] = "foo" + load_context = LoadContext(None, -1) + + with pytest.raises(ValueError, match="Expected object of type"): + MethodNode(state, load_context, trusted=None) + + +def test_operator_func_node_invalid_state(): + var = operator.methodcaller("fit") + state = operator_func_get_state(var, SaveContext(None, 0, {})) + state["__module__"] = "foo" + load_context = LoadContext(None, -1) + + with pytest.raises(ValueError, match="Expected module 'operator'"): + OperatorFuncNode(state, load_context, trusted=None) From d6a7f0f5daa4011135cb590e4fbbb3eaa4e60d33 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 25 Jul 2025 12:57:36 +0200 Subject: [PATCH 6/9] CI timeout --- .github/workflows/build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 2674cc1c..fed447b2 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -27,7 +27,7 @@ jobs: ] # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 15 + timeout-minutes: 30 steps: # The following two steps are workarounds to retrieve the "real" commit From 38cc7a9f30a4834c7ee96e7ed62792907c5d430d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 25 Jul 2025 13:46:40 +0200 Subject: [PATCH 7/9] changelog --- docs/changes.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/changes.rst b/docs/changes.rst index c0ddb20e..2c6e4e3f 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -11,6 +11,12 @@ skops Changelog v0.12 ----- +- `huggingface_hub` dependency is now optional. :pr:`462` by `Adrin Jalali`_. +- Objects' `__reduce__` is used when the output of it is of the form + `(type, (constructor_args,)` where type is the same as the `type(obj)`. + :pr:`467` by `Adrin Jalali`_. +- `MethodNode` and `OperatorNode` have a hardened audit now, removing certain security + vulnerabilities. :pr:`482` by `Adrin Jalali`_. v0.11 ----- From 727e1a38693dc66f8a5f0df7614362ff7acc6365 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 25 Jul 2025 13:53:03 +0200 Subject: [PATCH 8/9] codecov count change --- .codecov.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.codecov.yml b/.codecov.yml index 26aa1c1a..1d054b41 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -3,7 +3,7 @@ codecov: branch: main require_ci_to_pass: true notify: - after_n_builds: 12 + after_n_builds: 21 wait_for_ci: true ignore: - "skops/_min_dependencies.py" # This file is not tested, and won't be. From ec83dcf8487d0a1062e6ec3067f4420a0b5faa71 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 25 Jul 2025 13:58:23 +0200 Subject: [PATCH 9/9] trigger CI