Trust internal scikit-learn types needed for GB/HGB models #508
First changed file (skops' scikit-learn persistence module):
```diff
@@ -6,7 +6,7 @@
 from sklearn.tree._tree import Tree

 from ._audit import Node, get_tree
-from ._general import TypeNode, unsupported_get_state
+from ._general import ObjectNode, TypeNode, object_get_state, unsupported_get_state
 from ._protocol import PROTOCOL
 from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
 from .exceptions import UnsupportedTypeException
@@ -97,9 +97,94 @@
 LossFunction = None


+SKLEARN_INTERNAL_OBJECTS: set[type] = set()
+SKLEARN_TYPE_NAME_OVERRIDES: dict[type, str] = {}
+
+try:
+    from sklearn._loss.link import (
+        HalfLogitLink,
+        IdentityLink,
+        Interval,
+        LogitLink,
+        LogLink,
+        MultinomialLogit,
+    )
+
+    SKLEARN_INTERNAL_OBJECTS |= {
+        HalfLogitLink,
+        IdentityLink,
+        Interval,
+        LogLink,
+        LogitLink,
+        MultinomialLogit,
+    }
+except ImportError:
+    pass
+
+try:
+    from sklearn._loss.loss import (
+        AbsoluteError,
+        ExponentialLoss,
+        HalfBinomialLoss,
+        HalfGammaLoss,
+        HalfMultinomialLoss,
+        HalfPoissonLoss,
+        HalfSquaredError,
+        HuberLoss,
+        PinballLoss,
+    )
+
+    SKLEARN_INTERNAL_OBJECTS |= {
+        AbsoluteError,
+        ExponentialLoss,
+        HalfBinomialLoss,
+        HalfGammaLoss,
+        HalfMultinomialLoss,
+        HalfPoissonLoss,
+        HalfSquaredError,
+        HuberLoss,
+        PinballLoss,
+    }
+except ImportError:
+    pass
+
+if "CyHalfMultinomialLoss" in globals():
+    SKLEARN_INTERNAL_OBJECTS.add(CyHalfMultinomialLoss)
+    SKLEARN_TYPE_NAME_OVERRIDES[CyHalfMultinomialLoss] = (
+        "sklearn._loss._loss.CyHalfMultinomialLoss"
+    )
+
+try:
+    from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
+    from sklearn.ensemble._hist_gradient_boosting.predictor import TreePredictor
+
+    SKLEARN_INTERNAL_OBJECTS |= {_BinMapper, TreePredictor}
+except ImportError:
+    pass
+
+
 UNSUPPORTED_TYPES = {Birch}


+def get_sklearn_internal_type_name(type_: type) -> str:
+    return SKLEARN_TYPE_NAME_OVERRIDES.get(
+        type_, get_module(type_) + "." + type_.__name__
+    )
+
+
+TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES = [
+    get_sklearn_internal_type_name(type_) for type_ in SKLEARN_INTERNAL_OBJECTS
+]
+
+if not all(
+    type_name.startswith("sklearn.")
+    for type_name in TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES
+):
+    raise RuntimeError(
+        "All trusted sklearn internal type names must start with 'sklearn.'."
+    )
+
+
 def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
     # This method is for objects for which we have to use the __reduce__
     # method to get the state.
```
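For illustration, the new helper resolves names like this (assuming `get_module` returns the defining module's dotted path, as the fallback branch implies; this snippet is not part of the diff):

```python
from sklearn._loss.link import LogitLink

# Fallback branch: module path plus class name.
get_sklearn_internal_type_name(LogitLink)
# -> "sklearn._loss.link.LogitLink"

# Override branch: CyHalfMultinomialLoss is pinned to a stable name via
# SKLEARN_TYPE_NAME_OVERRIDES, as set up earlier in this hunk.
# -> "sklearn._loss._loss.CyHalfMultinomialLoss"
```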
```diff
@@ -265,6 +350,33 @@ def __init__(
         )


+def sklearn_internal_object_get_state(
+    obj: Any, save_context: SaveContext
+) -> dict[str, Any]:
+    state = object_get_state(obj, save_context)
+    module_name, _, class_name = get_sklearn_internal_type_name(type(obj)).rpartition(
+        "."
+    )
+    state["__module__"] = module_name
+    state["__class__"] = class_name
+    state["__loader__"] = "SklearnInternalObjectNode"
+    return state
```
Member
we don't need to add a new node to support these objects. They can be simply trusted in
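A rough sketch of that trusted-at-load alternative, built only from skops' public API as it appears in this PR's tests (`dumps`, `loads`, `get_untrusted_types`); passing `trusted=` to `loads` is an assumption here:

```python
from skops.io import dumps, get_untrusted_types, loads

# `model` is a fitted GradientBoosting/HistGradientBoosting estimator (assumed).
dumped = dumps(model)

# Without trusting the internal loss/link types, they would be reported here ...
untrusted = get_untrusted_types(data=dumped)

# ... and the caller would opt in explicitly when loading.
loaded = loads(dumped, trusted=untrusted)
```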
```diff
+
+
+class SklearnInternalObjectNode(ObjectNode):
+    def __init__(
+        self,
+        state: dict[str, Any],
+        load_context: LoadContext,
+        trusted: Optional[Sequence[str]] = None,
+    ) -> None:
+        super().__init__(state, load_context, trusted)
+        self.trusted = self._get_trusted(
+            trusted,
+            default=TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES,
+        )
+
+
 # TODO: remove once support for sklearn<1.2 is dropped.
 def _DictWithDeprecatedKeys_get_state(
     obj: Any, save_context: SaveContext
```
```diff
@@ -326,12 +438,16 @@ def _construct(self):
 if CyLossFunction is not None:
     GET_STATE_DISPATCH_FUNCTIONS.append((CyLossFunction, loss_get_state))

+for type_ in SKLEARN_INTERNAL_OBJECTS:
+    GET_STATE_DISPATCH_FUNCTIONS.append((type_, sklearn_internal_object_get_state))
+
 for type_ in UNSUPPORTED_TYPES:
     GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

 # tuples of type and function that creates the instance of that type
 NODE_TYPE_MAPPING: dict[tuple[str, int], Any] = {
     ("LossNode", PROTOCOL): LossNode,
+    ("SklearnInternalObjectNode", PROTOCOL): SklearnInternalObjectNode,
     ("TreeNode", PROTOCOL): TreeNode,
 }
```
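To see how the save and load halves meet: `sklearn_internal_object_get_state` stamps three routing keys into the persisted state, and the `__loader__` value matches the first element of the `NODE_TYPE_MAPPING` key. A hedged illustration follows; the dict values are taken from the schema assertions in the tests below, while the lookup line is an assumption about how skops consumes the mapping, not code from this PR:

```python
# State written for the Cython multinomial loss (attribute keys collected by
# object_get_state are omitted):
state = {
    "__module__": "sklearn._loss._loss",       # pinned via SKLEARN_TYPE_NAME_OVERRIDES
    "__class__": "CyHalfMultinomialLoss",
    "__loader__": "SklearnInternalObjectNode",
}

# On load, the loader string selects the node class (lookup shape assumed):
node_cls = NODE_TYPE_MAPPING[(state["__loader__"], PROTOCOL)]
# -> SklearnInternalObjectNode, which validates the type against
#    TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES before reconstruction.
```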
Second changed file (the test suite):
```diff
@@ -14,12 +14,19 @@
 import joblib
 import numpy as np
 import pytest
+import sklearn
 from scipy import sparse, special
 from sklearn.base import BaseEstimator, is_regressor
 from sklearn.compose import ColumnTransformer
 from sklearn.datasets import load_sample_images, make_classification, make_regression
 from sklearn.decomposition import SparseCoder
 from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
+from sklearn.ensemble import (
+    GradientBoostingClassifier,
+    GradientBoostingRegressor,
+    HistGradientBoostingClassifier,
+    HistGradientBoostingRegressor,
+)
 from sklearn.exceptions import SkipTestWarning
 from sklearn.experimental import enable_halving_search_cv  # noqa
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -72,6 +79,7 @@
 # Default settings for X
 N_SAMPLES = 120
 N_FEATURES = 20
+SKLEARN_VERSION = parse_version(sklearn.__version__)


 @pytest.fixture(autouse=True, scope="module")
@@ -438,6 +446,153 @@ def test_can_trust_types(type_):
     assert len(untrusted_types) == 0


+@pytest.mark.skipif(
+    SKLEARN_VERSION < parse_version("1.4"),
+    reason=(
+        "Before scikit-learn 1.4, GradientBoosting uses different internal loss "
+        "objects (`sklearn.ensemble._gb_losses`), which we don't try to support "
+        "as trusted types."
+    ),
+)
+@pytest.mark.parametrize(
+    ("estimator", "problem_type"),
+    [
+        pytest.param(
+            GradientBoostingClassifier(loss="log_loss", n_estimators=5),
+            "multiclass",
+            id="GradientBoostingClassifier-log_loss-multiclass",
+        ),
+        pytest.param(
+            GradientBoostingClassifier(loss="exponential", n_estimators=5),
+            "binary",
+            id="GradientBoostingClassifier-exponential",
+        ),
+        pytest.param(
+            GradientBoostingRegressor(loss="squared_error", n_estimators=5),
+            "regression",
+            id="GradientBoostingRegressor-squared_error",
+        ),
+        pytest.param(
+            GradientBoostingRegressor(loss="absolute_error", n_estimators=5),
+            "regression",
+            id="GradientBoostingRegressor-absolute_error",
+        ),
+        pytest.param(
+            GradientBoostingRegressor(loss="huber", n_estimators=5),
+            "regression",
+            id="GradientBoostingRegressor-huber",
+        ),
+        pytest.param(
+            GradientBoostingRegressor(loss="quantile", n_estimators=5, alpha=0.8),
+            "regression",
+            id="GradientBoostingRegressor-quantile",
+        ),
+        pytest.param(
+            HistGradientBoostingClassifier(loss="log_loss", max_iter=5),
+            "binary",
+            id="HistGradientBoostingClassifier-log_loss",
+        ),
+        pytest.param(
+            HistGradientBoostingRegressor(loss="gamma", max_iter=5),
+            "positive_regression",
+            id="HistGradientBoostingRegressor-gamma",
+        ),
+        pytest.param(
+            HistGradientBoostingRegressor(loss="poisson", max_iter=5),
+            "positive_regression",
+            id="HistGradientBoostingRegressor-poisson",
+        ),
+    ],
+)
+def test_gradient_boosting_estimators_have_no_untrusted_types(estimator, problem_type):
+    set_random_state(estimator, random_state=0)
+
+    if problem_type == "binary":
+        X, y = make_classification(
+            n_samples=N_SAMPLES,
+            n_features=N_FEATURES,
+            n_classes=2,
+            n_informative=5,
+            random_state=0,
+        )
+    elif problem_type == "multiclass":
+        X, y = make_classification(
+            n_samples=140,
+            n_features=N_FEATURES,
+            n_classes=3,
+            n_informative=8,
+            n_clusters_per_class=1,
+            random_state=0,
+        )
+    elif problem_type == "positive_regression":
+        X, y = make_regression(
+            n_samples=N_SAMPLES,
+            n_features=N_FEATURES,
+            random_state=0,
+        )
+        y = np.abs(y) + 1
+    else:
+        X, y = make_regression(
+            n_samples=N_SAMPLES,
+            n_features=N_FEATURES,
+            random_state=0,
+        )
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", module="sklearn")
+        estimator.fit(X, y)
+
+    dumped = dumps(estimator)
+
+    assert get_untrusted_types(data=dumped) == []
+
+    loaded = loads(dumped)
+    assert_method_outputs_equal(estimator, loaded, X)
+
+
+@pytest.mark.skipif(
+    SKLEARN_VERSION < parse_version("1.4"),
+    reason="CyHalfMultinomialLoss is not used by GradientBoosting before sklearn 1.4.",
+)
+def test_cyhalfmultinomialloss_is_serialized_under_sklearn_module():
+    estimator = GradientBoostingClassifier(loss="log_loss", n_estimators=5)
+    set_random_state(estimator, random_state=0)
+    X, y = make_classification(
+        n_samples=140,
+        n_features=N_FEATURES,
+        n_classes=3,
+        n_informative=8,
+        n_clusters_per_class=1,
+        random_state=0,
+    )
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", module="sklearn")
+        estimator.fit(X, y)
+
+    dumped = dumps(estimator)
+    with ZipFile(io.BytesIO(dumped), "r") as zip_file:
```
Member
not sure if going through the zip file is a good idea, we should save / load and check if the loaded object is correct, with correct loaded attributes.
```diff
+        schema = json.loads(zip_file.read("schema.json"))
+
+    found = []
+
+    def walk(obj):
+        if isinstance(obj, dict):
+            if obj.get("__class__") == "CyHalfMultinomialLoss":
+                found.append(obj)
+            for value in obj.values():
+                walk(value)
+        elif isinstance(obj, list):
+            for value in obj:
+                walk(value)
+
+    walk(schema)
+
+    assert len(found) == 1
+    assert found[0]["__module__"] == "sklearn._loss._loss"
+    assert found[0]["__loader__"] == "SklearnInternalObjectNode"
```
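A minimal sketch of the save/load check the reviewer suggests above, reusing helpers already imported in this test module (`dumps`, `loads`, `assert_method_outputs_equal`); the `_loss` attribute is an assumed internal of the fitted estimator:

```python
def check_loss_roundtrip(estimator, X):
    # `estimator` is a fitted GradientBoostingClassifier, `X` its training data.
    loaded = loads(dumps(estimator))

    # The loaded model should behave identically ...
    assert_method_outputs_equal(estimator, loaded, X)

    # ... and its internal loss object should come back as the same type.
    # NOTE: `_loss` is an assumed private attribute of the fitted estimator.
    assert type(loaded._loss) is type(estimator._loss)
```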
```diff
+
+
 @pytest.mark.parametrize(
     "estimator", _unsupported_estimators(), ids=_get_check_estimator_ids
 )
```
this is more of a test. Import shouldn't raise. Alternatively, we can filter out here anything which doesn't start with `sklearn.`
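A sketch of that filtering alternative: drop, rather than raise on, any resolved name outside the `sklearn.` namespace.

```python
# Replaces the RuntimeError check above; keeps only names under sklearn.
TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES = [
    name
    for name in (get_sklearn_internal_type_name(t) for t in SKLEARN_INTERNAL_OBJECTS)
    if name.startswith("sklearn.")
]
```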