diff --git a/docs/changes.rst b/docs/changes.rst index 57f2681f..3be2a186 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -12,6 +12,8 @@ skops Changelog v0.9 ---- +- Add support for `quantile-forest `__ + estimators. :pr:`384` by :user:`Reid Johnson `. - Fix an issue with visualizing Skops files for `scikit-learn` tree estimators. :pr:`386` by :user:`Reid Johnson `. diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 2c372771..0f1efb93 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -15,6 +15,7 @@ "scikit-learn-intelex": ("2021.7.1", "docs", None), "huggingface_hub": ("0.10.1", "install", None), "tabulate": ("0.8.8", "install", None), + "quantile-forest": ("1.0.0", "tests", None), "pytest": (PYTEST_MIN_VERSION, "tests", None), "pytest-cov": ("2.9.0", "tests", None), "flake8": ("3.8.2", "tests", None), diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 8d81f11a..d69b6123 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -15,7 +15,7 @@ # We load the dispatch functions from the corresponding modules and register # them. Old protocols are found in the 'old/' directory, with the protocol # version appended to the corresponding module name. -modules = ["._general", "._numpy", "._scipy", "._sklearn"] +modules = ["._general", "._numpy", "._scipy", "._sklearn", "._quantile_forest"] modules.extend([".old._general_v0", ".old._numpy_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py new file mode 100644 index 00000000..a9b14d47 --- /dev/null +++ b/skops/io/_quantile_forest.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any, Sequence + +from ._protocol import PROTOCOL +from ._sklearn import ReduceNode, reduce_get_state +from ._utils import LoadContext, SaveContext + +try: + from quantile_forest._quantile_forest_fast import QuantileForest +except ImportError: + QuantileForest = None + + +def quantile_forest_get_state( + obj: Any, + save_context: SaveContext, +) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) + state["__loader__"] = "QuantileForestNode" + return state + + +class QuantileForestNode(ReduceNode): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + if QuantileForest is None: + raise ImportError( + "`quantile_forest` is missing and needs to be installed in order to" + " load this model." + ) + + super().__init__( + state, + load_context, + constructor=QuantileForest, + trusted=trusted, + ) + self.trusted = self._get_trusted(trusted, []) + + +# tuples of type and function that gets the state of that type +if QuantileForest is not None: + GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] + +# tuples of type and function that creates the instance of that type +NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index 2c90c1cd..aeb7ce72 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -381,3 +381,49 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): assert_method_outputs_equal(estimator, loaded, X) visualize(dumped, trusted=trusted) + + +class TestQuantileForest: + """Tests for RandomForestQuantileRegressor and ExtraTreesQuantileRegressor""" + + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. + with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + + @pytest.fixture(autouse=True) + def quantile_forest(self): + quantile_forest = pytest.importorskip("quantile_forest") + return quantile_forest + + @pytest.fixture + def trusted(self): + # TODO: adjust once more types are trusted by default + return [ + "quantile_forest._quantile_forest.RandomForestQuantileRegressor", + "quantile_forest._quantile_forest.ExtraTreesQuantileRegressor", + "quantile_forest._quantile_forest_fast.QuantileForest", + ] + + tree_methods = [ + "RandomForestQuantileRegressor", + "ExtraTreesQuantileRegressor", + ] + + @pytest.mark.parametrize("tree_method", tree_methods) + def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method): + cls = getattr(quantile_forest, tree_method) + estimator = cls() + loaded = loads(dumps(estimator), trusted=trusted) + assert_params_equal(estimator.get_params(), loaded.get_params()) + + X, y = regr_data + estimator.fit(X, y) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) + assert_method_outputs_equal(estimator, loaded, X) + + visualize(dumped, trusted=trusted)