From ee7bde549b3c504e5b1c29ffb492108495c17cb6 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sat, 22 Jul 2023 02:55:00 -0700 Subject: [PATCH 01/11] Initial commit --- skops/io/_persist.py | 2 +- skops/io/_quantile_forest.py | 47 ++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 skops/io/_quantile_forest.py diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 8d81f11a..d69b6123 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -15,7 +15,7 @@ # We load the dispatch functions from the corresponding modules and register # them. Old protocols are found in the 'old/' directory, with the protocol # version appended to the corresponding module name. -modules = ["._general", "._numpy", "._scipy", "._sklearn"] +modules = ["._general", "._numpy", "._scipy", "._sklearn", "._quantile_forest"] modules.extend([".old._general_v0", ".old._numpy_v0"]) for module_name in modules: # register exposed functions for get_state and get_tree diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py new file mode 100644 index 00000000..8430ab12 --- /dev/null +++ b/skops/io/_quantile_forest.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Any, Sequence + +from quantile_forest._quantile_forest_fast import QuantileForest + +from ._protocol import PROTOCOL + +from ._sklearn import reduce_get_state, ReduceNode +from ._utils import LoadContext, SaveContext, get_module + + +def quantile_forest_get_state( + obj: Any, + save_context: SaveContext, + ) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) + state["__loader__"] = "QuantileForestNode" + return state + + +class QuantileForestNode(ReduceNode): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__( + state, + load_context, + constructor=QuantileForest, + trusted=trusted, + ) + self.trusted = self._get_trusted( 
+ trusted, + [get_module(QuantileForest) + ".QuantileForest"], + ) + + +# tuples of type and function that gets the state of that type +GET_STATE_DISPATCH_FUNCTIONS = [ + (QuantileForest, quantile_forest_get_state) +] + +# tuples of type and function that creates the instance of that type +NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} From 522fc4bbfa42ed19756088d114922ef55de68567 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sat, 22 Jul 2023 03:46:50 -0700 Subject: [PATCH 02/11] Update dependencies --- skops/_min_dependencies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 2c372771..0687b5c6 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -15,6 +15,7 @@ "scikit-learn-intelex": ("2021.7.1", "docs", None), "huggingface_hub": ("0.10.1", "install", None), "tabulate": ("0.8.8", "install", None), + "quantile-forest": ("1.0.0", "install", None), "pytest": (PYTEST_MIN_VERSION, "tests", None), "pytest-cov": ("2.9.0", "tests", None), "flake8": ("3.8.2", "tests", None), From 425f8b9ee0cddc073af01e7e01929a6a55b28eda Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sun, 23 Jul 2023 02:25:27 -0700 Subject: [PATCH 03/11] Add unit tests for quantile-forest -Adds base unit tests -Patches constructor visualization error --- skops/io/_visualize.py | 3 +++ skops/io/tests/test_external.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 9df80928..33b6c599 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -232,6 +232,9 @@ def walk_tree( "here: https://github.com/skops-dev/skops/issues" ) + if node_name == "constructor": + return + if isinstance(node, dict): num_nodes = len(node) for i, (key, val) in enumerate(node.items(), start=1): diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index 2c90c1cd..02252657 100644 --- 
a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -381,3 +381,48 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): assert_method_outputs_equal(estimator, loaded, X) visualize(dumped, trusted=trusted) + + +class TestQuantileForest: + """Tests for RandomForestQuantileRegressor and ExtraTreesQuantileRegressor""" + + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. + with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + + @pytest.fixture(autouse=True) + def quantile_forest(self): + quantile_forest = pytest.importorskip("quantile_forest") + return quantile_forest + + @pytest.fixture + def trusted(self): + # TODO: adjust once more types are trusted by default + return [ + "quantile_forest._quantile_forest.RandomForestQuantileRegressor", + "quantile_forest._quantile_forest.ExtraTreesQuantileRegressor", + "quantile_forest._quantile_forest_fast.QuantileForest", + ] + + def test_quantile_forest(self, quantile_forest, regr_data, trusted): + tree_methods = [ + quantile_forest.RandomForestQuantileRegressor, + quantile_forest.ExtraTreesQuantileRegressor, + ] + + for tree_method in tree_methods: + estimator = tree_method() + loaded = loads(dumps(estimator), trusted=trusted) + assert_params_equal(estimator.get_params(), loaded.get_params()) + + X, y = regr_data + estimator.fit(X, y) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) + assert_method_outputs_equal(estimator, loaded, X) + + visualize(dumped, trusted=trusted) From d9d022b837b6a9f5d5cb68a4d714275da9095bdc Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Sun, 23 Jul 2023 02:26:17 -0700 Subject: [PATCH 04/11] Update skops/_min_dependencies.py Co-authored-by: Benjamin Bossan --- skops/_min_dependencies.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 0687b5c6..0f1efb93 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -15,7 +15,7 @@ "scikit-learn-intelex": ("2021.7.1", "docs", None), "huggingface_hub": ("0.10.1", "install", None), "tabulate": ("0.8.8", "install", None), - "quantile-forest": ("1.0.0", "install", None), + "quantile-forest": ("1.0.0", "tests", None), "pytest": (PYTEST_MIN_VERSION, "tests", None), "pytest-cov": ("2.9.0", "tests", None), "flake8": ("3.8.2", "tests", None), From f39816b311864259c8e442a1029987dad8f96680 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Mon, 24 Jul 2023 10:14:50 -0700 Subject: [PATCH 05/11] Linting --- skops/io/_quantile_forest.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py index 8430ab12..f1c07ed7 100644 --- a/skops/io/_quantile_forest.py +++ b/skops/io/_quantile_forest.py @@ -5,15 +5,14 @@ from quantile_forest._quantile_forest_fast import QuantileForest from ._protocol import PROTOCOL - -from ._sklearn import reduce_get_state, ReduceNode +from ._sklearn import ReduceNode, reduce_get_state from ._utils import LoadContext, SaveContext, get_module def quantile_forest_get_state( - obj: Any, - save_context: SaveContext, - ) -> dict[str, Any]: + obj: Any, + save_context: SaveContext, +) -> dict[str, Any]: state = reduce_get_state(obj, save_context) state["__loader__"] = "QuantileForestNode" return state @@ -39,9 +38,7 @@ def __init__( # tuples of type and function that gets the state of that type -GET_STATE_DISPATCH_FUNCTIONS = [ - (QuantileForest, quantile_forest_get_state) -] +GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] # tuples of type and function that creates the instance of that type NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} From 
e0d59a8f2510472e77db60151624e428c87a758e Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Mon, 24 Jul 2023 10:20:17 -0700 Subject: [PATCH 06/11] Remove Visualization Fix In favor of separate PR. --- skops/io/_visualize.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 33b6c599..9df80928 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -232,9 +232,6 @@ def walk_tree( "here: https://github.com/skops-dev/skops/issues" ) - if node_name == "constructor": - return - if isinstance(node, dict): num_nodes = len(node) for i, (key, val) in enumerate(node.items(), start=1): From 27fa53256816dec71c83f2829d342732c11d2905 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Mon, 24 Jul 2023 10:29:26 -0700 Subject: [PATCH 07/11] Remove Default Trusted --- skops/io/_quantile_forest.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py index f1c07ed7..dbe47ed8 100644 --- a/skops/io/_quantile_forest.py +++ b/skops/io/_quantile_forest.py @@ -6,7 +6,7 @@ from ._protocol import PROTOCOL from ._sklearn import ReduceNode, reduce_get_state -from ._utils import LoadContext, SaveContext, get_module +from ._utils import LoadContext, SaveContext def quantile_forest_get_state( @@ -31,10 +31,7 @@ def __init__( constructor=QuantileForest, trusted=trusted, ) - self.trusted = self._get_trusted( - trusted, - [get_module(QuantileForest) + ".QuantileForest"], - ) + self.trusted = self._get_trusted(trusted, []) # tuples of type and function that gets the state of that type From fd4dfffd473e213031edf3940f319371d3a779ba Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Mon, 24 Jul 2023 10:57:49 -0700 Subject: [PATCH 08/11] Update test Uses pytest.mark.parametrize --- skops/io/tests/test_external.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/skops/io/tests/test_external.py 
b/skops/io/tests/test_external.py index 02252657..aeb7ce72 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -408,21 +408,22 @@ def trusted(self): "quantile_forest._quantile_forest_fast.QuantileForest", ] - def test_quantile_forest(self, quantile_forest, regr_data, trusted): - tree_methods = [ - quantile_forest.RandomForestQuantileRegressor, - quantile_forest.ExtraTreesQuantileRegressor, - ] + tree_methods = [ + "RandomForestQuantileRegressor", + "ExtraTreesQuantileRegressor", + ] - for tree_method in tree_methods: - estimator = tree_method() - loaded = loads(dumps(estimator), trusted=trusted) - assert_params_equal(estimator.get_params(), loaded.get_params()) + @pytest.mark.parametrize("tree_method", tree_methods) + def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method): + cls = getattr(quantile_forest, tree_method) + estimator = cls() + loaded = loads(dumps(estimator), trusted=trusted) + assert_params_equal(estimator.get_params(), loaded.get_params()) - X, y = regr_data - estimator.fit(X, y) - dumped = dumps(estimator) - loaded = loads(dumped, trusted=trusted) - assert_method_outputs_equal(estimator, loaded, X) + X, y = regr_data + estimator.fit(X, y) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) + assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted) From e9b0afbe181f57211218ec8a1e5278077aa89053 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Thu, 27 Jul 2023 19:32:23 -0700 Subject: [PATCH 09/11] Lazy import quantile-forest --- skops/io/_quantile_forest.py | 72 ++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py index dbe47ed8..b4a72fc6 100644 --- a/skops/io/_quantile_forest.py +++ b/skops/io/_quantile_forest.py @@ -2,40 +2,48 @@ from typing import Any, Sequence -from 
quantile_forest._quantile_forest_fast import QuantileForest - from ._protocol import PROTOCOL from ._sklearn import ReduceNode, reduce_get_state from ._utils import LoadContext, SaveContext -def quantile_forest_get_state( - obj: Any, - save_context: SaveContext, -) -> dict[str, Any]: - state = reduce_get_state(obj, save_context) - state["__loader__"] = "QuantileForestNode" - return state - - -class QuantileForestNode(ReduceNode): - def __init__( - self, - state: dict[str, Any], - load_context: LoadContext, - trusted: bool | Sequence[str] = False, - ) -> None: - super().__init__( - state, - load_context, - constructor=QuantileForest, - trusted=trusted, - ) - self.trusted = self._get_trusted(trusted, []) - - -# tuples of type and function that gets the state of that type -GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] - -# tuples of type and function that creates the instance of that type -NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} +def _lazy_import(): + from quantile_forest._quantile_forest_fast import QuantileForest + + def quantile_forest_get_state( + obj: Any, + save_context: SaveContext, + ) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) + state["__loader__"] = "QuantileForestNode" + return state + + class QuantileForestNode(ReduceNode): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + super().__init__( + state, + load_context, + constructor=QuantileForest, + trusted=trusted, + ) + self.trusted = self._get_trusted(trusted, []) + + return QuantileForest, QuantileForestNode, quantile_forest_get_state + + +try: + QuantileForest, QuantileForestNode, quantile_forest_get_state = _lazy_import() + + # tuples of type and function that gets the state of that type + GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] + + # tuples of type and function that creates the instance of that 
type + NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} +except ModuleNotFoundError: + GET_STATE_DISPATCH_FUNCTIONS = [] + NODE_TYPE_MAPPING = {} From c19c7ff1165e5d2c9a66ca3bb1d292df423d95c7 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Fri, 28 Jul 2023 08:18:31 -0700 Subject: [PATCH 10/11] Update lazy loading --- skops/io/_quantile_forest.py | 70 ++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/skops/io/_quantile_forest.py b/skops/io/_quantile_forest.py index b4a72fc6..a9b14d47 100644 --- a/skops/io/_quantile_forest.py +++ b/skops/io/_quantile_forest.py @@ -6,44 +6,46 @@ from ._sklearn import ReduceNode, reduce_get_state from ._utils import LoadContext, SaveContext - -def _lazy_import(): +try: from quantile_forest._quantile_forest_fast import QuantileForest - - def quantile_forest_get_state( - obj: Any, - save_context: SaveContext, - ) -> dict[str, Any]: - state = reduce_get_state(obj, save_context) - state["__loader__"] = "QuantileForestNode" - return state - - class QuantileForestNode(ReduceNode): - def __init__( - self, - state: dict[str, Any], - load_context: LoadContext, - trusted: bool | Sequence[str] = False, - ) -> None: - super().__init__( - state, - load_context, - constructor=QuantileForest, - trusted=trusted, +except ImportError: + QuantileForest = None + + +def quantile_forest_get_state( + obj: Any, + save_context: SaveContext, +) -> dict[str, Any]: + state = reduce_get_state(obj, save_context) + state["__loader__"] = "QuantileForestNode" + return state + + +class QuantileForestNode(ReduceNode): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: + if QuantileForest is None: + raise ImportError( + "`quantile_forest` is missing and needs to be installed in order to" + " load this model." 
+ ) - self.trusted = self._get_trusted(trusted, []) - return QuantileForest, QuantileForestNode, quantile_forest_get_state + super().__init__( + state, + load_context, + constructor=QuantileForest, + trusted=trusted, + ) + self.trusted = self._get_trusted(trusted, []) -try: - QuantileForest, QuantileForestNode, quantile_forest_get_state = _lazy_import() - - # tuples of type and function that gets the state of that type +# tuples of type and function that gets the state of that type +if QuantileForest is not None: GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] - # tuples of type and function that creates the instance of that type - NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} -except ModuleNotFoundError: - GET_STATE_DISPATCH_FUNCTIONS = [] - NODE_TYPE_MAPPING = {} +# tuples of type and function that creates the instance of that type +NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} From 0512fabfdcc36d9eae1356aee67e7393f7264c42 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Thu, 3 Aug 2023 09:36:38 -0700 Subject: [PATCH 11/11] Add changelog --- docs/changes.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changes.rst b/docs/changes.rst index 57f2681f..3be2a186 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -12,6 +12,8 @@ skops Changelog v0.9 ---- +- Add support for `quantile-forest <https://github.com/zillow/quantile-forest>`__ + estimators. :pr:`384` by :user:`Reid Johnson <reidjohnson>`. - Fix an issue with visualizing Skops files for `scikit-learn` tree estimators. :pr:`386` by :user:`Reid Johnson <reidjohnson>`.