-
Notifications
You must be signed in to change notification settings - Fork 61
ENH: Quantile Forest Support #384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ee7bde5
522fc4b
425f8b9
d9d022b
f39816b
e0d59a8
27fa532
fd4dfff
e9b0afb
c19c7ff
6482190
0512fab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Sequence | ||
|
|
||
| from ._protocol import PROTOCOL | ||
| from ._sklearn import ReduceNode, reduce_get_state | ||
| from ._utils import LoadContext, SaveContext | ||
|
|
||
| try: | ||
| from quantile_forest._quantile_forest_fast import QuantileForest | ||
| except ImportError: | ||
| QuantileForest = None | ||
|
|
||
|
|
||
| def quantile_forest_get_state( | ||
| obj: Any, | ||
| save_context: SaveContext, | ||
| ) -> dict[str, Any]: | ||
| state = reduce_get_state(obj, save_context) | ||
| state["__loader__"] = "QuantileForestNode" | ||
| return state | ||
|
|
||
|
|
||
| class QuantileForestNode(ReduceNode): | ||
| def __init__( | ||
| self, | ||
| state: dict[str, Any], | ||
| load_context: LoadContext, | ||
| trusted: bool | Sequence[str] = False, | ||
| ) -> None: | ||
| if QuantileForest is None: | ||
| raise ImportError( | ||
| "`quantile_forest` is missing and needs to be installed in order to" | ||
| " load this model." | ||
| ) | ||
|
|
||
| super().__init__( | ||
| state, | ||
| load_context, | ||
| constructor=QuantileForest, | ||
| trusted=trusted, | ||
| ) | ||
| self.trusted = self._get_trusted(trusted, []) | ||
|
|
||
|
|
||
| # tuples of type and function that gets the state of that type | ||
| if QuantileForest is not None: | ||
| GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)] | ||
|
|
||
| # tuples of type and function that creates the instance of that type | ||
| NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -381,3 +381,49 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): | |
| assert_method_outputs_equal(estimator, loaded, X) | ||
|
|
||
| visualize(dumped, trusted=trusted) | ||
|
|
||
|
|
||
| class TestQuantileForest: | ||
| """Tests for RandomForestQuantileRegressor and ExtraTreesQuantileRegressor""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this not missing QuantileForest?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are the user-facing classes which are explicitly tested, |
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def capture_stdout(self): | ||
| # Mock print and rich.print so that running these tests with pytest -s | ||
| # does not spam stdout. Other, more common methods of suppressing | ||
| # printing to stdout don't seem to work, perhaps because of pytest. | ||
| with patch("builtins.print", Mock()), patch("rich.print", Mock()): | ||
| yield | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def quantile_forest(self): | ||
| quantile_forest = pytest.importorskip("quantile_forest") | ||
| return quantile_forest | ||
|
|
||
| @pytest.fixture | ||
| def trusted(self): | ||
| # TODO: adjust once more types are trusted by default | ||
| return [ | ||
| "quantile_forest._quantile_forest.RandomForestQuantileRegressor", | ||
| "quantile_forest._quantile_forest.ExtraTreesQuantileRegressor", | ||
| "quantile_forest._quantile_forest_fast.QuantileForest", | ||
| ] | ||
|
|
||
| tree_methods = [ | ||
| "RandomForestQuantileRegressor", | ||
| "ExtraTreesQuantileRegressor", | ||
| ] | ||
|
|
||
| @pytest.mark.parametrize("tree_method", tree_methods) | ||
| def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method): | ||
| cls = getattr(quantile_forest, tree_method) | ||
| estimator = cls() | ||
| loaded = loads(dumps(estimator), trusted=trusted) | ||
| assert_params_equal(estimator.get_params(), loaded.get_params()) | ||
|
|
||
| X, y = regr_data | ||
| estimator.fit(X, y) | ||
| dumped = dumps(estimator) | ||
| loaded = loads(dumped, trusted=trusted) | ||
| assert_method_outputs_equal(estimator, loaded, X) | ||
|
|
||
| visualize(dumped, trusted=trusted) | ||
Uh oh!
There was an error while loading. Please reload this page.