Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,7 @@ node_modules

exports
trash

# Claude
CLAUDE.md
.claude/
3 changes: 3 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ skops Changelog

v0.14
-----
- Trust internal scikit-learn types needed by GradientBoosting and
HistGradientBoosting models, so they no longer surface as untrusted types.
:pr:`513` by :user:`cakedev0 <cakedev0>` and `Adrin Jalali`_.

v0.13
-----
Expand Down
1,347 changes: 673 additions & 674 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PRIMITIVE_TYPE_NAMES,
SCIPY_UFUNC_TYPE_NAMES,
SKLEARN_ESTIMATOR_TYPE_NAMES,
SKLEARN_INTERNAL_TYPE_NAMES,
)
from ._utils import (
LoadContext,
Expand Down Expand Up @@ -479,7 +480,10 @@ def __init__(

self.children = {"attrs": attrs}
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, default=SKLEARN_ESTIMATOR_TYPE_NAMES)
self.trusted = self._get_trusted(
trusted,
default=SKLEARN_ESTIMATOR_TYPE_NAMES + SKLEARN_INTERNAL_TYPE_NAMES,
)

def _construct(self):
cls = gettype(self.module_name, self.class_name)
Expand Down
25 changes: 21 additions & 4 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
CyPinballLoss,
}
except ImportError:
pass
CyHalfMultinomialLoss = None

# This import is for the parent class of all loss functions, which is used to
# set the dispatch function for all loss functions.
Expand Down Expand Up @@ -230,18 +230,30 @@ def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
state = reduce_get_state(obj, save_context)
state["__loader__"] = "LossNode"
elif type(obj) == reduce[1][0]:
# the output is of the form:
# The output is commonly of the form:
# >>> CyPinballLoss(1).__reduce__()
# (<cyfunction __pyx_unpickle_CyPinballLoss at 0x7b1d00099ff0>,
# (<class '_loss.CyPinballLoss'>, 232784418, (1.0,)))
#
# CyHalfMultinomialLoss differs slightly and may return a 3-tuple:
# >>> CyHalfMultinomialLoss().__reduce__()
# (<cyfunction __pyx_unpickle_CyHalfMultinomialLoss at 0x...>,
# (<class '_loss.CyHalfMultinomialLoss'>, 238750788, None), ())
#
# In that case, the constructor takes no args and the state lives
# in reduce[2].
state = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "LossNode",
}
state["__reduce__"] = {}
state["__reduce__"]["args"] = get_state(reduce[1][2], save_context)
state["content"] = get_state({}, save_context)
if len(reduce) == 3:
state["__reduce__"]["args"] = get_state((), save_context)
state["content"] = get_state(reduce[2], save_context)
else:
state["__reduce__"]["args"] = get_state(reduce[1][2], save_context)
state["content"] = get_state({}, save_context)

return state

Expand Down Expand Up @@ -326,6 +338,11 @@ def _construct(self):
if CyLossFunction is not None:
GET_STATE_DISPATCH_FUNCTIONS.append((CyLossFunction, loss_get_state))

# CyHalfMultinomialLoss is not a subclass of CyLossFunction, so it needs its
# own dispatch entry. It's already in ALLOWED_LOSSES so LossNode will trust it.
if CyHalfMultinomialLoss is not None:
GET_STATE_DISPATCH_FUNCTIONS.append((CyHalfMultinomialLoss, loss_get_state))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work because CyHalfMultinomialLoss doesn't have the same reduce/state shape as other Cy* losses.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are passing, can you give me a test case where it fails?

Copy link
Copy Markdown

@cakedev0 cakedev0 Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this `loss_get_state`, it works (I tried the fix on nightly, 1.7 and 1.6; it works there too):

def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Build the ``LossNode`` state dict for a loss object.

    The layout of ``obj.__reduce__()`` decides how the state is produced:

    * ``reduce[0]`` is the object's own class: the work is delegated to
      ``reduce_get_state`` and only the loader name is overridden.
    * ``reduce[1][0]`` is the object's class (the ``__pyx_unpickle_*`` helper
      form). This commonly looks like::

          >>> CyPinballLoss(1).__reduce__()
          (<cyfunction __pyx_unpickle_CyPinballLoss at 0x7b1d00099ff0>,
           (<class '_loss.CyPinballLoss'>, 232784418, (1.0,)))

      ``CyHalfMultinomialLoss`` differs slightly and returns a 3-tuple::

          >>> CyHalfMultinomialLoss().__reduce__()
          (<cyfunction __pyx_unpickle_CyHalfMultinomialLoss at 0x...>,
           (<class '_loss.CyHalfMultinomialLoss'>, 238750788, None), ())

      In that case the constructor takes no args and the tuple state lives
      in ``reduce[2]``.

    Parameters
    ----------
    obj : Any
        The loss object to serialize; it must implement ``__reduce__``.
    save_context : SaveContext
        Passed through unchanged to ``reduce_get_state`` / ``get_state``.

    Returns
    -------
    dict[str, Any]
        State dict whose ``"__loader__"`` entry is ``"LossNode"``.

    Raises
    ------
    TypeError
        If ``obj.__reduce__()`` matches neither supported layout. Previously
        this case fell through to ``return state`` with ``state`` unbound and
        surfaced as an ``UnboundLocalError``.
    """
    reduced = obj.__reduce__()
    obj_type = type(obj)

    if obj_type is reduced[0]:
        state = reduce_get_state(obj, save_context)
        state["__loader__"] = "LossNode"
        return state

    if obj_type is not reduced[1][0]:
        # Fail loudly with the offending type instead of leaving ``state``
        # unbound.
        raise TypeError(
            f"Cannot get the state of loss object of type {obj_type.__name__}: "
            "unsupported __reduce__() layout."
        )

    state = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(obj_type),
        "__loader__": "LossNode",
    }
    if len(reduced) == 3:
        # 3-tuple form: no constructor args, state is carried separately.
        args, content = (), reduced[2]
    else:
        # 2-tuple form: constructor args are the third item of the helper args.
        args, content = reduced[1][2], {}

    state["__reduce__"] = {"args": get_state(args, save_context)}
    state["content"] = get_state(content, save_context)
    return state

Copy link
Copy Markdown

@cakedev0 cakedev0 Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> The tests are passing, can you give me a test case where it fails?

Locally (using pixi env: ci-sklearn18), I have one single test that fails: FAILED skops/io/tests/test_persist.py::test_gradient_boosting_estimators_have_no_untrusted_types[GradientBoostingClassifier-log_loss-multiclass] - TypeError: _loss.CyHalfMultinomialLoss() argument after * must be an iterable, not NoneType

The CI has the same failure, but also other failures for other sklearn versions, which are the ones I skipped in my PR with:

@pytest.mark.skipif(
    SKLEARN_VERSION < parse_version("1.4"),
    reason=(
        "Before scikit-learn 1.4, GradientBoosting uses different internal loss "
        "objects (`sklearn.ensemble._gb_losses`), which we don't try to support "
        "as trusted types."
    ),
)


for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

Expand Down
64 changes: 64 additions & 0 deletions skops/io/_trusted_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,70 @@
if get_type_name(estimator_class).startswith("sklearn.")
]

# Non-public scikit-learn classes that fitted GradientBoosting and
# HistGradientBoosting models hold on to: loss functions, link functions, the
# bin mapper and the tree predictor. They are collected here so their names can
# be offered as trusted types when such models are serialized.
_SKLEARN_INTERNAL_TYPES: list[type] = []

# Each group is imported on a best-effort basis: if the private module (or one
# of the names in it) is unavailable, that whole group is simply left out.
try:
    from sklearn._loss.link import (
        HalfLogitLink,
        IdentityLink,
        Interval,
        LogitLink,
        LogLink,
        MultinomialLogit,
    )
except ImportError:
    pass
else:
    _SKLEARN_INTERNAL_TYPES += [
        HalfLogitLink,
        IdentityLink,
        Interval,
        LogitLink,
        LogLink,
        MultinomialLogit,
    ]

try:
    from sklearn._loss.loss import (
        AbsoluteError,
        ExponentialLoss,
        HalfBinomialLoss,
        HalfGammaLoss,
        HalfMultinomialLoss,
        HalfPoissonLoss,
        HalfSquaredError,
        HuberLoss,
        PinballLoss,
    )
except ImportError:
    pass
else:
    _SKLEARN_INTERNAL_TYPES += [
        AbsoluteError,
        ExponentialLoss,
        HalfBinomialLoss,
        HalfGammaLoss,
        HalfMultinomialLoss,
        HalfPoissonLoss,
        HalfSquaredError,
        HuberLoss,
        PinballLoss,
    ]

try:
    from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
    from sklearn.ensemble._hist_gradient_boosting.predictor import TreePredictor
except ImportError:
    pass
else:
    _SKLEARN_INTERNAL_TYPES += [_BinMapper, TreePredictor]

# Keep only the names that resolve inside the ``sklearn.`` namespace.
SKLEARN_INTERNAL_TYPE_NAMES = [
    type_name
    for type_name in map(get_type_name, _SKLEARN_INTERNAL_TYPES)
    if type_name.startswith("sklearn.")
]

with warnings.catch_warnings():
# This is to suppress deprecation warning coming from the fact that scipy reports
# numpy.core for ufuncs, and numpy.core is deprecated and renamed to numpy._core
Expand Down
113 changes: 113 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@
import joblib
import numpy as np
import pytest
import sklearn
from scipy import sparse, special
from sklearn.base import BaseEstimator, is_regressor
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_sample_images, make_classification, make_regression
from sklearn.decomposition import SparseCoder
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.ensemble import (
GradientBoostingClassifier,
GradientBoostingRegressor,
HistGradientBoostingClassifier,
HistGradientBoostingRegressor,
)
from sklearn.exceptions import SkipTestWarning
from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.feature_extraction.text import TfidfVectorizer
Expand Down Expand Up @@ -438,6 +445,112 @@ def test_can_trust_types(type_):
assert len(untrusted_types) == 0


@pytest.mark.skipif(
    parse_version(sklearn.__version__) < parse_version("1.4"),
    reason=(
        "Before scikit-learn 1.4, GradientBoosting uses different internal loss "
        "objects (sklearn.ensemble._gb_losses) that are not supported as trusted types."
    ),
)
@pytest.mark.parametrize(
    ("estimator", "problem_type"),
    [
        pytest.param(
            GradientBoostingClassifier(loss="log_loss", n_estimators=5),
            "multiclass",
            id="GradientBoostingClassifier-log_loss-multiclass",
        ),
        pytest.param(
            GradientBoostingClassifier(loss="exponential", n_estimators=5),
            "binary",
            id="GradientBoostingClassifier-exponential",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="squared_error", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-squared_error",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="absolute_error", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-absolute_error",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="huber", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-huber",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="quantile", n_estimators=5, alpha=0.8),
            "regression",
            id="GradientBoostingRegressor-quantile",
        ),
        pytest.param(
            HistGradientBoostingClassifier(loss="log_loss", max_iter=5),
            "binary",
            id="HistGradientBoostingClassifier-log_loss",
        ),
        pytest.param(
            HistGradientBoostingRegressor(loss="gamma", max_iter=5),
            "positive_regression",
            id="HistGradientBoostingRegressor-gamma",
        ),
        pytest.param(
            HistGradientBoostingRegressor(loss="poisson", max_iter=5),
            "positive_regression",
            id="HistGradientBoostingRegressor-poisson",
        ),
    ],
)
def test_gradient_boosting_estimators_have_no_untrusted_types(estimator, problem_type):
    """Check that fitted GB/HGB estimators round-trip with no untrusted types.

    Each estimator is fitted on data matching ``problem_type``, dumped, and
    then checked three ways: the dump reports no untrusted types, the loaded
    estimator has equal attributes, and its method outputs match the original
    on the training data.
    """
    set_random_state(estimator, random_state=0)

    if problem_type == "multiclass":
        # Uses its own sample count (140) rather than N_SAMPLES; per the
        # original note, n_samples must be > n_classes * n_clusters_per_class
        # to avoid errors.
        X, y = make_classification(
            n_samples=140,
            n_features=N_FEATURES,
            n_classes=3,
            n_informative=8,
            n_clusters_per_class=1,
            random_state=0,
        )
    elif problem_type == "binary":
        X, y = make_classification(
            n_samples=N_SAMPLES,
            n_features=N_FEATURES,
            n_classes=2,
            n_informative=5,
            random_state=0,
        )
    else:
        # Every remaining problem type is fitted on the same regression data.
        X, y = make_regression(
            n_samples=N_SAMPLES,
            n_features=N_FEATURES,
            random_state=0,
        )
        if problem_type == "positive_regression":
            # Shift the targets so that they are all >= 1.
            y = np.abs(y) + 1

    estimator.fit(X, y)

    payload = dumps(estimator)
    assert get_untrusted_types(data=payload) == []

    restored = loads(payload)
    assert_params_equal(estimator.__dict__, restored.__dict__)
    assert_method_outputs_equal(estimator, restored, X)


@pytest.mark.parametrize(
"estimator", _unsupported_estimators(), ids=_get_check_estimator_ids
)
Expand Down