diff --git a/docs/changes.rst b/docs/changes.rst
index 90691b76..7e208380 100644
--- a/docs/changes.rst
+++ b/docs/changes.rst
@@ -19,7 +19,8 @@ v0.4
- Add `model_format` argument to :meth:`skops.hub_utils.init` to be stored in
`config.json` so that we know how to load a model from the repository.
:pr:`242` by `Merve Noyan`_.
-
+- Persistence now supports bytes and bytearrays, added tests to verify that
+ LightGBM, XGBoost, and CatBoost work now. :pr:`244` by `Benjamin Bossan`_.
v0.3
----
diff --git a/docs/persistence.rst b/docs/persistence.rst
index 7fced29b..19959f85 100644
--- a/docs/persistence.rst
+++ b/docs/persistence.rst
@@ -87,6 +87,26 @@ means if you have custom functions (say, a custom function to be used with
most ``numpy`` and ``scipy`` functions should work. Therefore, you can actually
save built-in functions like ``numpy.sqrt``.
+Supported libraries
+-------------------
+
+Skops intends to support all of **scikit-learn**, that is, not only its
+estimators, but also other classes like cross validation splitters. Furthermore,
+most types from **numpy** and **scipy** should be supported, such as (sparse)
+arrays, dtypes, random generators, and ufuncs.
+
+Apart from this core, we plan to support machine learning libraries commonly
+used by the community. So far, those are:
+
+- `LightGBM <https://lightgbm.readthedocs.io/>`_ (scikit-learn API)
+- `XGBoost <https://xgboost.readthedocs.io/>`_ (scikit-learn API)
+- `CatBoost <https://catboost.ai/>`_
+
+If you run into a problem using any of the mentioned libraries, this could mean
+there is a bug in skops. Please open an issue on `our issue tracker
+<https://github.com/skops-dev/skops/issues>`_ (but please check first if a
+corresponding issue already exists).
+
Roadmap
-------
diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py
index a3f1ced7..a60579c4 100644
--- a/skops/_min_dependencies.py
+++ b/skops/_min_dependencies.py
@@ -27,6 +27,10 @@
"matplotlib": ("3.3", "docs, tests", None),
"pandas": ("1", "docs, tests", None),
"typing_extensions": ("3.7", "install", "python_full_version < '3.8'"),
+ # required for persistence tests of external libraries
+ "lightgbm": ("3", "tests", None),
+ "xgboost": ("1.6", "tests", None),
+ "catboost": ("1.0", "tests", None),
}
diff --git a/skops/io/_general.py b/skops/io/_general.py
index 126bcfaf..10ef9a0f 100644
--- a/skops/io/_general.py
+++ b/skops/io/_general.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+import io
import json
+import uuid
from functools import partial
from types import FunctionType, MethodType
from typing import Any, Sequence
@@ -475,12 +477,64 @@ def _construct(self):
return json.loads(self.content)
+def bytes_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
+ f_name = f"{uuid.uuid4()}.bin"
+ save_context.zip_file.writestr(f_name, obj)
+ res = {
+ "__class__": obj.__class__.__name__,
+ "__module__": get_module(type(obj)),
+ "__loader__": "BytesNode",
+ "file": f_name,
+ }
+ return res
+
+
+def bytearray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
+ res = bytes_get_state(obj, save_context)
+ res["__loader__"] = "BytearrayNode"
+ return res
+
+
+class BytesNode(Node):
+ def __init__(
+ self,
+ state: dict[str, Any],
+ load_context: LoadContext,
+ trusted: bool | Sequence[str] = False,
+ ) -> None:
+ super().__init__(state, load_context, trusted)
+ self.trusted = self._get_trusted(trusted, [bytes])
+ self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))}
+
+ def _construct(self):
+ content = self.children["content"].getvalue()
+ return content
+
+
+class BytearrayNode(BytesNode):
+ def __init__(
+ self,
+ state: dict[str, Any],
+ load_context: LoadContext,
+ trusted: bool | Sequence[str] = False,
+ ) -> None:
+ super().__init__(state, load_context, trusted)
+ self.trusted = self._get_trusted(trusted, [bytearray])
+
+ def _construct(self):
+ content_bytes = super()._construct()
+ content_bytearray = bytearray(list(content_bytes))
+ return content_bytearray
+
+
# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(dict, dict_get_state),
(list, list_get_state),
(set, set_get_state),
(tuple, tuple_get_state),
+ (bytes, bytes_get_state),
+ (bytearray, bytearray_get_state),
(slice, slice_get_state),
(FunctionType, function_get_state),
(MethodType, method_get_state),
@@ -494,6 +548,8 @@ def _construct(self):
"ListNode": ListNode,
"SetNode": SetNode,
"TupleNode": TupleNode,
+ "BytesNode": BytesNode,
+ "BytearrayNode": BytearrayNode,
"SliceNode": SliceNode,
"FunctionNode": FunctionNode,
"MethodNode": MethodNode,
diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py
new file mode 100644
index 00000000..ead14b29
--- /dev/null
+++ b/skops/io/tests/_utils.py
@@ -0,0 +1,170 @@
+import sys
+import warnings
+
+import numpy as np
+from scipy import sparse
+from sklearn.base import BaseEstimator
+from sklearn.utils._testing import assert_allclose_dense_sparse
+
+# TODO: Investigate why that seems to be an issue on MacOS (only observed with
+# Python 3.8)
+ATOL = 1e-6 if sys.platform == "darwin" else 1e-7
+
+
+def _is_steps_like(obj):
+ # helper function to check if an object is something like Pipeline.steps,
+ # i.e. a list of tuples of names and estimators
+ if not isinstance(obj, list): # must be a list
+ return False
+
+ if not obj: # must not be empty
+ return False
+
+ if not isinstance(obj[0], tuple): # must be list of tuples
+ return False
+
+ lens = set(map(len, obj))
+ if not lens == {2}: # all elements must be length 2 tuples
+ return False
+
+ keys, vals = list(zip(*obj))
+
+ if len(keys) != len(set(keys)): # keys must be unique
+ return False
+
+ if not all(map(lambda x: isinstance(x, (type(None), BaseEstimator)), vals)):
+ # values must be BaseEstimators or None
+ return False
+
+ return True
+
+
+def _assert_generic_objects_equal(val1, val2):
+ def _is_builtin(val):
+ # Check if value is a builtin type
+ return getattr(getattr(val, "__class__", {}), "__module__", None) == "builtins"
+
+ if isinstance(val1, (list, tuple, np.ndarray)):
+ assert len(val1) == len(val2)
+ for subval1, subval2 in zip(val1, val2):
+ _assert_generic_objects_equal(subval1, subval2)
+ return
+
+ assert type(val1) == type(val2)
+ if hasattr(val1, "__dict__"):
+ assert_params_equal(val1.__dict__, val2.__dict__)
+ elif _is_builtin(val1):
+ assert val1 == val2
+ else:
+ # not a normal Python class, could be e.g. a Cython class
+ assert val1.__reduce__() == val2.__reduce__()
+
+
+def _assert_tuples_equal(val1, val2):
+ assert len(val1) == len(val2)
+ for subval1, subval2 in zip(val1, val2):
+ _assert_vals_equal(subval1, subval2)
+
+
+def _assert_vals_equal(val1, val2):
+ if hasattr(val1, "__getstate__"):
+ # This includes BaseEstimator since they implement __getstate__ and
+ # that returns the parameters as well.
+ #
+ # Some objects return a tuple of parameters, others a dict.
+ state1 = val1.__getstate__()
+ state2 = val2.__getstate__()
+ assert type(state1) == type(state2)
+ if isinstance(state1, tuple):
+ _assert_tuples_equal(state1, state2)
+ else:
+ assert_params_equal(val1.__getstate__(), val2.__getstate__())
+ elif sparse.issparse(val1):
+ assert sparse.issparse(val2) and ((val1 - val2).nnz == 0)
+ elif isinstance(val1, (np.ndarray, np.generic)):
+ if len(val1.dtype) == 0:
+ # for arrays with at least 2 dimensions, check that contiguity is
+ # preserved
+ if val1.squeeze().ndim > 1:
+ assert val1.flags["C_CONTIGUOUS"] is val2.flags["C_CONTIGUOUS"]
+ assert val1.flags["F_CONTIGUOUS"] is val2.flags["F_CONTIGUOUS"]
+ if val1.dtype == object:
+ assert val2.dtype == object
+ assert val1.shape == val2.shape
+ for subval1, subval2 in zip(val1, val2):
+ _assert_generic_objects_equal(subval1, subval2)
+ else:
+ # simple comparison of arrays with simple dtypes, almost all
+ # arrays are of this sort.
+ np.testing.assert_array_equal(val1, val2)
+ elif len(val1.shape) == 1:
+ # comparing arrays with structured dtypes, but they have to be 1D
+ # arrays. This is what we get from the Tree's state.
+ assert np.all([x == y for x, y in zip(val1, val2)])
+ else:
+ # we don't know what to do with these values, for now.
+ assert False
+ elif isinstance(val1, (tuple, list)):
+ assert len(val1) == len(val2)
+ for subval1, subval2 in zip(val1, val2):
+ _assert_vals_equal(subval1, subval2)
+ elif isinstance(val1, float) and np.isnan(val1):
+ assert np.isnan(val2)
+ elif isinstance(val1, dict):
+ # dictionaries are compared by comparing their values recursively.
+ assert set(val1.keys()) == set(val2.keys())
+ for key in val1:
+ _assert_vals_equal(val1[key], val2[key])
+ elif hasattr(val1, "__dict__") and hasattr(val2, "__dict__"):
+ _assert_vals_equal(val1.__dict__, val2.__dict__)
+ elif isinstance(val1, np.ufunc):
+ assert val1 == val2
+ elif val1.__class__.__module__ == "builtins":
+ assert val1 == val2
+ else:
+ _assert_generic_objects_equal(val1, val2)
+
+
+def assert_params_equal(params1, params2):
+ # helper function to compare estimator dictionaries of parameters
+ assert len(params1) == len(params2)
+ assert set(params1.keys()) == set(params2.keys())
+ for key in params1:
+ with warnings.catch_warnings():
+ # this is to silence the deprecation warning from _DictWithDeprecatedKeys
+ warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
+ val1, val2 = params1[key], params2[key]
+ assert type(val1) == type(val2)
+
+ if _is_steps_like(val1):
+ # Deal with Pipeline.steps, FeatureUnion.transformer_list, etc.
+ assert _is_steps_like(val2)
+ val1, val2 = dict(val1), dict(val2)
+
+ if isinstance(val1, (tuple, list)):
+ assert len(val1) == len(val2)
+ for subval1, subval2 in zip(val1, val2):
+ _assert_vals_equal(subval1, subval2)
+ elif isinstance(val1, dict):
+ assert_params_equal(val1, val2)
+ else:
+ _assert_vals_equal(val1, val2)
+
+
+def assert_method_outputs_equal(estimator, loaded, X):
+ # helper function that checks the output of all supported methods
+ for method in [
+ "predict",
+ "predict_proba",
+ "decision_function",
+ "transform",
+ "predict_log_proba",
+ ]:
+ err_msg = (
+ f"{estimator.__class__.__name__}.{method}() doesn't produce the same"
+ " results after loading the persisted model."
+ )
+ if hasattr(estimator, method):
+ X_out1 = getattr(estimator, method)(X)
+ X_out2 = getattr(loaded, method)(X)
+ assert_allclose_dense_sparse(X_out1, X_out2, err_msg=err_msg, atol=ATOL)
diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py
new file mode 100644
index 00000000..fc58f5f4
--- /dev/null
+++ b/skops/io/tests/test_external.py
@@ -0,0 +1,314 @@
+"""Test persistence of "external" packages
+
+Packages that are not builtins, standard lib, numpy, scipy, or scikit-learn.
+
+"""
+
+import pytest
+from sklearn.datasets import make_classification, make_regression
+
+from skops.io import dumps, loads
+from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal
+
+# Default settings for generated data
+N_SAMPLES = 30
+N_FEATURES = 10
+N_CLASSES = 4 # for classification only
+
+
+@pytest.fixture(scope="module")
+def clf_data():
+ X, y = make_classification(
+ n_samples=N_SAMPLES,
+ n_classes=N_CLASSES,
+ n_features=N_FEATURES,
+ random_state=0,
+ n_redundant=1,
+ n_informative=N_FEATURES - 1,
+ )
+ return X, y
+
+
+@pytest.fixture(scope="module")
+def regr_data():
+ X, y = make_regression(n_samples=N_SAMPLES, n_features=N_FEATURES, random_state=0)
+ return X, y
+
+
+@pytest.fixture(scope="module")
+def rank_data(clf_data):
+ X, y = clf_data
+ group = [10 for _ in range(N_SAMPLES // 10)]
+ n = sum(group)
+ if N_SAMPLES > n:
+ group[-1] += N_SAMPLES - n
+ assert sum(group) == N_SAMPLES
+ return X, y, group
+
+
+class TestLightGBM:
+ """Tests for LGBMClassifier, LGBMRegressor, LGBMRanker"""
+
+ @pytest.fixture(autouse=True)
+ def lgbm(self):
+ lgbm = pytest.importorskip("lightgbm")
+ return lgbm
+
+ @pytest.fixture
+ def trusted(self):
+ # TODO: adjust once more types are trusted by default
+ return [
+ "collections.defaultdict",
+ "lightgbm.basic.Booster",
+ "lightgbm.sklearn.LGBMClassifier",
+ "lightgbm.sklearn.LGBMRegressor",
+ "lightgbm.sklearn.LGBMRanker",
+ "numpy.int32",
+ "numpy.int64",
+ "sklearn.preprocessing._label.LabelEncoder",
+ ]
+
+ boosting_types = ["gbdt", "dart", "goss", "rf"]
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_classifier(self, lgbm, clf_data, trusted, boosting_type):
+ kw = {}
+ if boosting_type == "rf":
+ kw["bagging_fraction"] = 0.5
+ kw["bagging_freq"] = 2
+
+ estimator = lgbm.LGBMClassifier(boosting_type=boosting_type, **kw)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = clf_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_regressor(self, lgbm, regr_data, trusted, boosting_type):
+ kw = {}
+ if boosting_type == "rf":
+ kw["bagging_fraction"] = 0.5
+ kw["bagging_freq"] = 2
+
+ estimator = lgbm.LGBMRegressor(boosting_type=boosting_type, **kw)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = regr_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_ranker(self, lgbm, rank_data, trusted, boosting_type):
+ kw = {}
+ if boosting_type == "rf":
+ kw["bagging_fraction"] = 0.5
+ kw["bagging_freq"] = 2
+
+ estimator = lgbm.LGBMRanker(boosting_type=boosting_type, **kw)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y, group = rank_data
+ estimator.fit(X, y, group=group)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+
+class TestXGBoost:
+ """Tests for XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor, XGBRanker
+
+ Known bugs:
+
+    - When initializing with tree_method=None, its value resolves to "exact", but
+ after loading, it resolves to "auto" when calling get_params().
+ - When initializing with tree_method='gpu_hist' and gpu_id=None, the
+ latter's value resolves to 0, but after loading, it resolves to -1, when
+ calling get_params()
+
+ These discrepancies occur regardless of skops, so they're a problem in
+ xgboost itself. We assume that this has no practical consequences and thus
+ avoid testing these cases. See https://github.com/dmlc/xgboost/issues/8596
+
+ """
+
+ @pytest.fixture(autouse=True)
+ def xgboost(self):
+ xgboost = pytest.importorskip("xgboost")
+ return xgboost
+
+ @pytest.fixture
+ def trusted(self):
+ # TODO: adjust once more types are trusted by default
+ return [
+ "xgboost.sklearn.XGBClassifier",
+ "xgboost.sklearn.XGBRegressor",
+ "xgboost.sklearn.XGBRFClassifier",
+ "xgboost.sklearn.XGBRFRegressor",
+ "xgboost.sklearn.XGBRanker",
+ "builtins.bytearray",
+ "xgboost.core.Booster",
+ ]
+
+ boosters = ["gbtree", "gblinear", "dart"]
+ tree_methods = ["approx", "hist", "auto"]
+
+ @pytest.mark.parametrize("booster", boosters)
+ @pytest.mark.parametrize("tree_method", tree_methods)
+ def test_classifier(self, xgboost, clf_data, trusted, booster, tree_method):
+ if (booster == "gblinear") and (tree_method != "approx"):
+ # This parameter combination is not supported in XGBoost
+ return
+
+ estimator = xgboost.XGBClassifier(booster=booster, tree_method=tree_method)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = clf_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("booster", boosters)
+ @pytest.mark.parametrize("tree_method", tree_methods)
+ def test_regressor(self, xgboost, regr_data, trusted, booster, tree_method):
+ if (booster == "gblinear") and (tree_method != "approx"):
+ # This parameter combination is not supported in XGBoost
+ return
+
+ estimator = xgboost.XGBRegressor(booster=booster, tree_method=tree_method)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = regr_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("booster", boosters)
+ @pytest.mark.parametrize("tree_method", tree_methods)
+ def test_rf_classifier(self, xgboost, clf_data, trusted, booster, tree_method):
+ if (booster == "gblinear") and (tree_method != "approx"):
+ # This parameter combination is not supported in XGBoost
+ return
+
+ estimator = xgboost.XGBRFClassifier(booster=booster, tree_method=tree_method)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = clf_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("booster", boosters)
+ @pytest.mark.parametrize("tree_method", tree_methods)
+ def test_rf_regressor(self, xgboost, regr_data, trusted, booster, tree_method):
+ if (booster == "gblinear") and (tree_method != "approx"):
+ # This parameter combination is not supported in XGBoost
+ return
+
+ estimator = xgboost.XGBRFRegressor(booster=booster, tree_method=tree_method)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = regr_data
+ estimator.fit(X, y)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("booster", boosters)
+ @pytest.mark.parametrize("tree_method", tree_methods)
+ def test_ranker(self, xgboost, rank_data, trusted, booster, tree_method):
+ if (booster == "gblinear") and (tree_method != "approx"):
+ # This parameter combination is not supported in XGBoost
+ return
+
+ estimator = xgboost.XGBRanker(booster=booster, tree_method=tree_method)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y, group = rank_data
+ estimator.fit(X, y, group=group)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+
+class TestCatboost:
+ """Tests for CatBoostClassifier, CatBoostRegressor, and CatBoostRanker"""
+
+ # CatBoost data is a little different so that it works as categorical data
+ @pytest.fixture(scope="module")
+ def cb_clf_data(self, clf_data):
+ X, y = clf_data
+ X = (X - X.min()).astype(int)
+ return X, y
+
+ @pytest.fixture(scope="module")
+ def cb_regr_data(self, regr_data):
+ X, y = regr_data
+ X = (X - X.min()).astype(int)
+ return X, y
+
+ @pytest.fixture(scope="module")
+ def cb_rank_data(self, rank_data):
+ X, y, group = rank_data
+ X = (X - X.min()).astype(int)
+ group_id = sum([[i] * n for i, n in enumerate(group)], [])
+ return X, y, group_id
+
+ @pytest.fixture(autouse=True)
+ def catboost(self):
+ catboost = pytest.importorskip("catboost")
+ return catboost
+
+ @pytest.fixture
+ def trusted(self):
+ # TODO: adjust once more types are trusted by default
+ return [
+ "builtins.bytes",
+ "numpy.float32",
+ "numpy.float64",
+ "catboost.core.CatBoostClassifier",
+ "catboost.core.CatBoostRegressor",
+ "catboost.core.CatBoostRanker",
+ ]
+
+ boosting_types = ["Ordered", "Plain"]
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type):
+ estimator = catboost.CatBoostClassifier(boosting_type=boosting_type)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = cb_clf_data
+ estimator.fit(X, y, cat_features=[0, 1])
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type):
+ estimator = catboost.CatBoostRegressor(boosting_type=boosting_type)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y = cb_regr_data
+ estimator.fit(X, y, cat_features=[0, 1])
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
+
+ @pytest.mark.parametrize("boosting_type", boosting_types)
+ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type):
+ estimator = catboost.CatBoostRanker(boosting_type=boosting_type)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_params_equal(estimator.get_params(), loaded.get_params())
+
+ X, y, group_id = cb_rank_data
+ estimator.fit(X, y, cat_features=[0, 1], group_id=group_id)
+ loaded = loads(dumps(estimator), trusted=trusted)
+ assert_method_outputs_equal(estimator, loaded, X)
diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py
index 012ab7f7..bd9c1e20 100644
--- a/skops/io/tests/test_persist.py
+++ b/skops/io/tests/test_persist.py
@@ -2,7 +2,6 @@
import inspect
import io
import json
-import sys
import warnings
from collections import Counter
from functools import partial, wraps
@@ -41,11 +40,7 @@
)
from sklearn.utils import all_estimators, check_random_state
from sklearn.utils._tags import _safe_tags
-from sklearn.utils._testing import (
- SkipTest,
- assert_allclose_dense_sparse,
- set_random_state,
-)
+from sklearn.utils._testing import SkipTest, set_random_state
from sklearn.utils.estimator_checks import (
_construct_instance,
_enforce_estimator_tags_y,
@@ -59,15 +54,12 @@
from skops.io._trusted_types import SKLEARN_ESTIMATOR_TYPE_NAMES
from skops.io._utils import LoadContext, SaveContext, _get_state, get_state
from skops.io.exceptions import UnsupportedTypeException
+from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal
# Default settings for X
N_SAMPLES = 50
N_FEATURES = 20
-# TODO: Investigate why that seems to be an issue on MacOS (only observed with
-# Python 3.8)
-ATOL = 1e-6 if sys.platform == "darwin" else 1e-7
-
@pytest.fixture(autouse=True, scope="module")
def debug_dispatch_functions():
@@ -258,146 +250,6 @@ def _unsupported_estimators(type_filter=None):
yield estimator
-def _is_steps_like(obj):
- # helper function to check if an object is something like Pipeline.steps,
- # i.e. a list of tuples of names and estimators
- if not isinstance(obj, list): # must be a list
- return False
-
- if not obj: # must not be empty
- return False
-
- if not isinstance(obj[0], tuple): # must be list of tuples
- return False
-
- lens = set(map(len, obj))
- if not lens == {2}: # all elements must be length 2 tuples
- return False
-
- keys, vals = list(zip(*obj))
-
- if len(keys) != len(set(keys)): # keys must be unique
- return False
-
- if not all(map(lambda x: isinstance(x, (type(None), BaseEstimator)), vals)):
- # values must be BaseEstimators or None
- return False
-
- return True
-
-
-def _assert_generic_objects_equal(val1, val2):
- def _is_builtin(val):
- # Check if value is a builtin type
- return getattr(getattr(val, "__class__", {}), "__module__", None) == "builtins"
-
- if isinstance(val1, (list, tuple, np.ndarray)):
- assert len(val1) == len(val2)
- for subval1, subval2 in zip(val1, val2):
- _assert_generic_objects_equal(subval1, subval2)
- return
-
- assert type(val1) == type(val2)
- if hasattr(val1, "__dict__"):
- assert_params_equal(val1.__dict__, val2.__dict__)
- elif _is_builtin(val1):
- assert val1 == val2
- else:
- # not a normal Python class, could be e.g. a Cython class
- assert val1.__reduce__() == val2.__reduce__()
-
-
-def _assert_tuples_equal(val1, val2):
- assert len(val1) == len(val2)
- for subval1, subval2 in zip(val1, val2):
- _assert_vals_equal(subval1, subval2)
-
-
-def _assert_vals_equal(val1, val2):
- if hasattr(val1, "__getstate__"):
- # This includes BaseEstimator since they implement __getstate__ and
- # that returns the parameters as well.
- #
- # Some objects return a tuple of parameters, others a dict.
- state1 = val1.__getstate__()
- state2 = val2.__getstate__()
- assert type(state1) == type(state2)
- if isinstance(state1, tuple):
- _assert_tuples_equal(state1, state2)
- else:
- assert_params_equal(val1.__getstate__(), val2.__getstate__())
- elif sparse.issparse(val1):
- assert sparse.issparse(val2) and ((val1 - val2).nnz == 0)
- elif isinstance(val1, (np.ndarray, np.generic)):
- if len(val1.dtype) == 0:
- # for arrays with at least 2 dimensions, check that contiguity is
- # preserved
- if val1.squeeze().ndim > 1:
- assert val1.flags["C_CONTIGUOUS"] is val2.flags["C_CONTIGUOUS"]
- assert val1.flags["F_CONTIGUOUS"] is val2.flags["F_CONTIGUOUS"]
- if val1.dtype == object:
- assert val2.dtype == object
- assert val1.shape == val2.shape
- for subval1, subval2 in zip(val1, val2):
- _assert_generic_objects_equal(subval1, subval2)
- else:
- # simple comparison of arrays with simple dtypes, almost all
- # arrays are of this sort.
- np.testing.assert_array_equal(val1, val2)
- elif len(val1.shape) == 1:
- # comparing arrays with structured dtypes, but they have to be 1D
- # arrays. This is what we get from the Tree's state.
- assert np.all([x == y for x, y in zip(val1, val2)])
- else:
- # we don't know what to do with these values, for now.
- assert False
- elif isinstance(val1, (tuple, list)):
- assert len(val1) == len(val2)
- for subval1, subval2 in zip(val1, val2):
- _assert_vals_equal(subval1, subval2)
- elif isinstance(val1, float) and np.isnan(val1):
- assert np.isnan(val2)
- elif isinstance(val1, dict):
- # dictionaries are compared by comparing their values recursively.
- assert set(val1.keys()) == set(val2.keys())
- for key in val1:
- _assert_vals_equal(val1[key], val2[key])
- elif hasattr(val1, "__dict__") and hasattr(val2, "__dict__"):
- _assert_vals_equal(val1.__dict__, val2.__dict__)
- elif isinstance(val1, np.ufunc):
- assert val1 == val2
- elif val1.__class__.__module__ == "builtins":
- assert val1 == val2
- else:
- _assert_generic_objects_equal(val1, val2)
-
-
-def assert_params_equal(params1, params2):
- # helper function to compare estimator dictionaries of parameters
- assert len(params1) == len(params2)
- assert set(params1.keys()) == set(params2.keys())
- for key in params1:
- with warnings.catch_warnings():
- # this is to silence the deprecation warning from _DictWithDeprecatedKeys
- warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
- val1, val2 = params1[key], params2[key]
- assert type(val1) == type(val2)
-
- if _is_steps_like(val1):
- # Deal with Pipeline.steps, FeatureUnion.transformer_list, etc.
- assert _is_steps_like(val2)
- val1, val2 = dict(val1), dict(val2)
-
- if isinstance(val1, (tuple, list)):
- assert len(val1) == len(val2)
- for subval1, subval2 in zip(val1, val2):
- _assert_vals_equal(subval1, subval2)
- elif isinstance(val1, dict):
- assert_params_equal(val1, val2)
- else:
- _assert_vals_equal(val1, val2)
-
-
@pytest.mark.parametrize(
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
)
@@ -493,22 +345,7 @@ def test_can_persist_fitted(estimator):
assert_params_equal(estimator.__dict__, loaded.__dict__)
assert not any(type_ in SKLEARN_ESTIMATOR_TYPE_NAMES for type_ in untrusted_types)
-
- for method in [
- "predict",
- "predict_proba",
- "decision_function",
- "transform",
- "predict_log_proba",
- ]:
- err_msg = (
- f"{estimator.__class__.__name__}.{method}() doesn't produce the same"
- " results after loading the persisted model."
- )
- if hasattr(estimator, method):
- X_pred1 = getattr(estimator, method)(X)
- X_pred2 = getattr(loaded, method)(X)
- assert_allclose_dense_sparse(X_pred1, X_pred2, err_msg=err_msg, atol=ATOL)
+ assert_method_outputs_equal(estimator, loaded, X)
@pytest.mark.parametrize(
@@ -1002,3 +839,29 @@ def test_when_given_object_referenced_twice_loads_as_one_object(obj):
persisted_object = loads(dumps(an_object), trusted=True)
assert persisted_object["obj_1"] is persisted_object["obj_2"]
+
+
+class EstimatorWithBytes(BaseEstimator):
+ def fit(self, X, y, **fit_params):
+ self.bytes_ = b"hello"
+ self.bytearray_ = bytearray([0, 1, 2, 253, 254, 255])
+ return self
+
+
+def test_estimator_with_bytes():
+ est = EstimatorWithBytes().fit(None, None)
+ loaded = loads(dumps(est), trusted=True)
+ assert_params_equal(est.__dict__, loaded.__dict__)
+
+
+def test_estimator_with_bytes_files_created(tmp_path):
+ est = EstimatorWithBytes().fit(None, None)
+ f_name = tmp_path / "estimator.skops"
+ dump(est, f_name)
+ file = Path(f_name)
+ assert file.exists()
+
+ with ZipFile(f_name, "r") as input_zip:
+ files = input_zip.namelist()
+ bin_files = [file for file in files if file.endswith(".bin")]
+ assert len(bin_files) == 2