diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 16173cf9..ec28562b 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -17,15 +17,16 @@ jobs: fail-fast: false # need to see which ones fail matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python: ["3.8", "3.9", "3.10"] + python: ["3.8", "3.9", "3.10", "3.11"] # this is to make the CI run on different sklearn versions include: - python: "3.8" sklearn_version: "1.0" - # TODO: add sklearn 1.1 when we add 3.11 support - python: "3.9" - sklearn_version: "1.2" + sklearn_version: "1.1" - python: "3.10" + sklearn_version: "1.2" + - python: "3.11" sklearn_version: "nightly" diff --git a/setup.py b/setup.py index 2262dcea..69d800a5 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def setup_package(): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ], python_requires=">=3.8", diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 35219a12..565263fc 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -29,7 +29,8 @@ # required for persistence tests of external libraries "lightgbm": ("3", "tests", None), "xgboost": ("1.6", "tests", None), - "catboost": ("1.0", "tests", None), + # TODO: remove condition when catboost supports python 3.11 + "catboost": ("1.0", "tests", "python_version < '3.11'"), } diff --git a/skops/io/_general.py b/skops/io/_general.py index 10ef9a0f..9bc2254a 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -401,10 +401,11 @@ def _construct(self): return instance attrs = self.children["attrs"].construct() - if hasattr(instance, "__setstate__"): - instance.__setstate__(attrs) - else: - instance.__dict__.update(attrs) + if attrs is not None: + if hasattr(instance, "__setstate__"): + instance.__setstate__(attrs) + else: + instance.__dict__.update(attrs) return instance diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 9da5392a..14ba4a87 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -72,13 +72,15 @@ def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # call __getstate__ directly. attrs = reduce[2] elif hasattr(obj, "__getstate__"): - attrs = obj.__getstate__() + # since python311 __getstate__ is defined for `object` and might return + # None + attrs = obj.__getstate__() or {} elif hasattr(obj, "__dict__"): attrs = obj.__dict__ else: attrs = {} - if not isinstance(attrs, dict): + if not isinstance(attrs, (dict, tuple)): raise UnsupportedTypeException( f"Objects of type {res['__class__']} not supported yet" ) @@ -119,8 +121,13 @@ def _construct(self): if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) - else: + elif isinstance(attrs, dict): instance.__dict__.update(attrs) + else: + # we (probably) got tuple attrs but cannot setstate with them + raise UnsupportedTypeException( + f"Objects of type {constructor} are not supported yet" + ) return instance diff --git a/skops/io/tests/_utils.py b/skops/io/tests/_utils.py index 6b02c3af..9081a9fc 100644 --- a/skops/io/tests/_utils.py +++ b/skops/io/tests/_utils.py @@ -67,10 +67,13 @@ def _assert_tuples_equal(val1, val2): def _assert_vals_equal(val1, val2): - if hasattr(val1, "__getstate__"): + if type(val1) == type: # e.g. could be np.int64 + assert val1 is val2 + elif hasattr(val1, "__getstate__") and (val1.__getstate__() is not None): # This includes BaseEstimator since they implement __getstate__ and # that returns the parameters as well. - # + # Since Python 3.11, all objects have a __getstate__ but they return + # None by default, in which case this check is not performed. # Some objects return a tuple of parameters, others a dict. state1 = val1.__getstate__() state2 = val2.__getstate__() @@ -126,14 +129,9 @@ def _assert_vals_equal(val1, val2): def assert_params_equal(params1, params2): - # due to https://github.com/scikit-learn/scikit-learn/pull/22094, after - # loading an sklearn estimator, there might be an entry called - # "__sklearn_pickle_version__" in the __dict__ that wasn't there before. We - # just ignore it. - params1.pop("__sklearn_pickle_version__", None) - params2.pop("__sklearn_pickle_version__", None) - # helper function to compare estimator dictionaries of parameters + if params1 is None and params2 is None: + return assert len(params1) == len(params2) assert set(params1.keys()) == set(params2.keys()) for key in params1: