Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ skops Changelog

v0.9
----
- Add support for `quantile-forest <https://github.com/zillow/quantile-forest>`__
estimators. :pr:`384` by :user:`Reid Johnson <reidjohnson>`.
- Fix an issue with visualizing Skops files for `scikit-learn` tree estimators.
:pr:`386` by :user:`Reid Johnson <reidjohnson>`.

Expand Down
1 change: 1 addition & 0 deletions skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"scikit-learn-intelex": ("2021.7.1", "docs", None),
"huggingface_hub": ("0.10.1", "install", None),
"tabulate": ("0.8.8", "install", None),
"quantile-forest": ("1.0.0", "tests", None),
"pytest": (PYTEST_MIN_VERSION, "tests", None),
"pytest-cov": ("2.9.0", "tests", None),
"flake8": ("3.8.2", "tests", None),
Expand Down
2 changes: 1 addition & 1 deletion skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# We load the dispatch functions from the corresponding modules and register
# them. Old protocols are found in the 'old/' directory, with the protocol
# version appended to the corresponding module name.
modules = ["._general", "._numpy", "._scipy", "._sklearn"]
modules = ["._general", "._numpy", "._scipy", "._sklearn", "._quantile_forest"]
Comment thread
adrinjalali marked this conversation as resolved.
modules.extend([".old._general_v0", ".old._numpy_v0"])
for module_name in modules:
# register exposed functions for get_state and get_tree
Expand Down
51 changes: 51 additions & 0 deletions skops/io/_quantile_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any, Sequence

from ._protocol import PROTOCOL
from ._sklearn import ReduceNode, reduce_get_state
from ._utils import LoadContext, SaveContext

try:
from quantile_forest._quantile_forest_fast import QuantileForest
except ImportError:
QuantileForest = None


def quantile_forest_get_state(
obj: Any,
save_context: SaveContext,
) -> dict[str, Any]:
state = reduce_get_state(obj, save_context)
state["__loader__"] = "QuantileForestNode"
return state


class QuantileForestNode(ReduceNode):
def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str] = False,
) -> None:
if QuantileForest is None:
raise ImportError(
"`quantile_forest` is missing and needs to be installed in order to"
" load this model."
)

super().__init__(
state,
load_context,
constructor=QuantileForest,
trusted=trusted,
)
self.trusted = self._get_trusted(trusted, [])


# tuples of type and function that gets the state of that type
if QuantileForest is not None:
GET_STATE_DISPATCH_FUNCTIONS = [(QuantileForest, quantile_forest_get_state)]

# tuples of type and function that creates the instance of that type
NODE_TYPE_MAPPING = {("QuantileForestNode", PROTOCOL): QuantileForestNode}
46 changes: 46 additions & 0 deletions skops/io/tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,49 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type):
assert_method_outputs_equal(estimator, loaded, X)

visualize(dumped, trusted=trusted)


class TestQuantileForest:
"""Tests for RandomForestQuantileRegressor and ExtraTreesQuantileRegressor"""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not missing QuantileForest?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the user-facing classes which are explicitly tested, QuantileForest is the underlying Cython object shared by the user-facing classes and isn't explicitly tested here (but is serialized by the loading/dumping).


@pytest.fixture(autouse=True)
def capture_stdout(self):
# Mock print and rich.print so that running these tests with pytest -s
# does not spam stdout. Other, more common methods of suppressing
# printing to stdout don't seem to work, perhaps because of pytest.
with patch("builtins.print", Mock()), patch("rich.print", Mock()):
yield

@pytest.fixture(autouse=True)
def quantile_forest(self):
quantile_forest = pytest.importorskip("quantile_forest")
return quantile_forest

@pytest.fixture
def trusted(self):
# TODO: adjust once more types are trusted by default
return [
"quantile_forest._quantile_forest.RandomForestQuantileRegressor",
"quantile_forest._quantile_forest.ExtraTreesQuantileRegressor",
"quantile_forest._quantile_forest_fast.QuantileForest",
]

tree_methods = [
"RandomForestQuantileRegressor",
"ExtraTreesQuantileRegressor",
]

@pytest.mark.parametrize("tree_method", tree_methods)
def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method):
cls = getattr(quantile_forest, tree_method)
estimator = cls()
loaded = loads(dumps(estimator), trusted=trusted)
assert_params_equal(estimator.get_params(), loaded.get_params())

X, y = regr_data
estimator.fit(X, y)
dumped = dumps(estimator)
loaded = loads(dumped, trusted=trusted)
assert_method_outputs_equal(estimator, loaded, X)

visualize(dumped, trusted=trusted)