Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ jobs:
fail-fast: false # need to see which ones fail
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python: ["3.8", "3.9", "3.10"]
python: ["3.8", "3.9", "3.10", "3.11"]
# this is to make the CI run on different sklearn versions
include:
- python: "3.8"
sklearn_version: "1.0"
# TODO: add sklearn 1.1 when we add 3.11 support
- python: "3.9"
sklearn_version: "1.2"
sklearn_version: "1.1"
- python: "3.10"
sklearn_version: "1.2"
- python: "3.11"
sklearn_version: "nightly"


Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def setup_package():
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
],
python_requires=">=3.8",
Expand Down
3 changes: 2 additions & 1 deletion skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
# required for persistence tests of external libraries
"lightgbm": ("3", "tests", None),
"xgboost": ("1.6", "tests", None),
"catboost": ("1.0", "tests", None),
# TODO: remove condition when catboost supports python 3.11
"catboost": ("1.0", "tests", "python_version < '3.11'"),
}


Expand Down
9 changes: 5 additions & 4 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,11 @@ def _construct(self):
return instance

attrs = self.children["attrs"].construct()
if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
instance.__dict__.update(attrs)
if attrs is not None:
if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
instance.__dict__.update(attrs)

return instance

Expand Down
13 changes: 10 additions & 3 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# call __getstate__ directly.
attrs = reduce[2]
elif hasattr(obj, "__getstate__"):
attrs = obj.__getstate__()
# since python311 __getstate__ is defined for `object` and might return
# None
attrs = obj.__getstate__() or {}
elif hasattr(obj, "__dict__"):
attrs = obj.__dict__
else:
attrs = {}

if not isinstance(attrs, dict):
if not isinstance(attrs, (dict, tuple)):
raise UnsupportedTypeException(
f"Objects of type {res['__class__']} not supported yet"
)
Expand Down Expand Up @@ -119,8 +121,13 @@ def _construct(self):

if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
elif isinstance(attrs, dict):
instance.__dict__.update(attrs)
else:
# we (probably) got tuple attrs but cannot setstate with them
raise UnsupportedTypeException(
f"Objects of type {constructor} are not supported yet"
)

return instance

Expand Down
16 changes: 7 additions & 9 deletions skops/io/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ def _assert_tuples_equal(val1, val2):


def _assert_vals_equal(val1, val2):
if hasattr(val1, "__getstate__"):
if type(val1) == type: # e.g. could be np.int64
assert val1 is val2
elif hasattr(val1, "__getstate__") and (val1.__getstate__() is not None):
# This includes BaseEstimator since they implement __getstate__ and
# that returns the parameters as well.
#
# Since Python 3.11, all objects have a __getstate__ but they return
# None by default, in which case this check is not performed.
# Some objects return a tuple of parameters, others a dict.
state1 = val1.__getstate__()
state2 = val2.__getstate__()
Expand Down Expand Up @@ -126,14 +129,9 @@ def _assert_vals_equal(val1, val2):


def assert_params_equal(params1, params2):
# due to https://github.com/scikit-learn/scikit-learn/pull/22094, after
# loading an sklearn estimator, there might be an entry called
# "__sklearn_pickle_version__" in the __dict__ that wasn't there before. We
# just ignore it.
params1.pop("__sklearn_pickle_version__", None)
params2.pop("__sklearn_pickle_version__", None)

# helper function to compare estimator dictionaries of parameters
if params1 is None and params2 is None:
return
assert len(params1) == len(params2)
assert set(params1.keys()) == set(params2.keys())
for key in params1:
Expand Down