diff --git a/docs/model_card.rst b/docs/model_card.rst index d3b726d4..0128a5e0 100644 --- a/docs/model_card.rst +++ b/docs/model_card.rst @@ -79,7 +79,10 @@ plots, save them on disk and then add them to the card by passing the path name to the :meth:`.Card.add_plot` method. For tables, you can pass either dictionaries with the key being the header and the values being list of row entries, or a pandas ``DataFrame``; use the :meth:`.Card.add_table` method for -this. +this. If you would like to add permutation importance results, you can pass +your importances to :meth:`.Card.add_permutation_importances`. If you want to +have multiple importance plots, you should pass a file name and a title for the +plot. This will create a boxplot and write it to the model card for you. To add content to an existing subsection, or create a new subsection, use a ``"/"`` to indicate the subsection. E.g. let's assume you would like to add a diff --git a/examples/plot_model_card.py b/examples/plot_model_card.py index 7a5ff3b3..5e68b7e0 100644 --- a/examples/plot_model_card.py +++ b/examples/plot_model_card.py @@ -20,6 +20,7 @@ from sklearn.datasets import load_breast_cancer from sklearn.ensemble import HistGradientBoostingClassifier from sklearn.experimental import enable_halving_search_cv # noqa +from sklearn.inspection import permutation_importance from sklearn.metrics import ( ConfusionMatrixDisplay, accuracy_score, @@ -153,6 +154,14 @@ **{"Model description/Evaluation Results/Confusion Matrix": "confusion_matrix.png"} ) +importances = permutation_importance(model, X_test, y_test, n_repeats=10) +model_card.add_permutation_importances( + importances, + X_test.columns, + plot_file=Path(local_repo) / "importance.png", + plot_name="Permutation Importance", +) + cv_results = model.cv_results_ clf_report = classification_report( y_test, y_pred, output_dict=True, target_names=["malignant", "benign"] diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index ede501bd..65e27299 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -17,6 +17,7 @@ from skops.card._templates import CONTENT_PLACEHOLDER, SKOPS_TEMPLATE, Templates from skops.io import load +from skops.utils.importutils import import_or_raise # Repr attributes can be used to control the behavior of repr aRepr = Repr() @@ -206,7 +207,7 @@ def split_subsection_names(key: str) -> list[str]: def _getting_started_code( - file_name: str, model_format: Literal["pickle", "skops"], indent=" " + file_name: str, model_format: Literal["pickle", "skops"], indent: str = " " ) -> list[str]: # get lines of code required to load the model lines = [ @@ -1085,11 +1086,64 @@ def add_metrics( "You can find the details about evaluation process and " "the evaluation results." ) - self._metrics.update(kwargs) self._add_metrics(section, self._metrics, description=description) return self + def add_permutation_importances( + self, + permutation_importances, + columns: Sequence[str], + plot_file: str = "permutation_importances.png", + plot_name: str = "Permutation Importances", + overwrite: bool = False, + ) -> "Card": + """Plots permutation importance and saves it to model card. + + Parameters + ---------- + permutation_importances : sklearn.utils.Bunch + Output of :func:`sklearn.inspection.permutation_importance`. + + columns : str, list or pandas.Index + Column names of the data used to generate importances. + + plot_file : str + Filename for the plot. + + plot_name : str + Name of the plot. + + overwrite : bool (default=False) + Whether to overwrite the permutation importance plot file, if a plot by that + name already exists. + + Returns + ------- + self : object + Card object. + """ + plt = import_or_raise("matplotlib.pyplot", "permutation importance") + + if Path(plot_file).exists() and overwrite is False: + raise ValueError( + f"{str(plot_file)} already exists. Set `overwrite` to `True` or pass a" + " different filename for the plot." + ) + sorted_importances_idx = permutation_importances.importances_mean.argsort() + _, ax = plt.subplots() + ax.boxplot( + x=permutation_importances.importances[sorted_importances_idx].T, + labels=columns[sorted_importances_idx], + vert=False, + ) + ax.set_title(plot_name) + ax.set_xlabel("Decrease in Score") + plt.savefig(plot_file) + self.add_plot(**{plot_name: plot_file}) + + return self + def _add_metrics( self, section: str, diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 28f28f90..5c3731f1 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -5,13 +5,14 @@ import textwrap from pathlib import Path -import matplotlib.pyplot as plt import numpy as np import pytest import sklearn from huggingface_hub import CardData, metadata_load from sklearn.datasets import load_iris +from sklearn.inspection import permutation_importance from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import f1_score, make_scorer from sklearn.neighbors import KNeighborsClassifier from skops import hub_utils @@ -403,6 +404,96 @@ def test_add_twice(self, model_card): assert text1 == text2 +def test_permutation_importances( + iris_estimator, iris_data, model_card, destination_path +): + X, y = iris_data + result = permutation_importance( + iris_estimator, X, y, n_repeats=10, random_state=42, n_jobs=2 + ) + + model_card.add_permutation_importances( + result, + X.columns, + Path(destination_path) / "importance.png", + "Permutation Importance", + ) + temp_path = Path(destination_path) / "importance.png" + assert f"![Permutation Importance]({temp_path}" in model_card.render() + + +def test_multiple_permutation_importances( + iris_estimator, iris_data, model_card, destination_path +): + X, y = iris_data + result = permutation_importance( + iris_estimator, X, y, n_repeats=10, random_state=42, n_jobs=2 + ) + model_card.add_permutation_importances( + result, X.columns, plot_file=Path(destination_path) / "importance.png" + ) + f1 = make_scorer(f1_score, average="micro") + result = permutation_importance( + iris_estimator, X, y, scoring=f1, n_repeats=10, random_state=42, n_jobs=2 + ) + model_card.add_permutation_importances( + result, + X.columns, + plot_file=Path(destination_path) / "f1_importance.png", + plot_name="Permutation Importance on f1", + ) + # check for default one + temp_path = Path(destination_path) / "importance.png" + assert f"![Permutation Importances]({temp_path}" in model_card.render() + # check for F1 + temp_path_f1 = Path(destination_path) / "f1_importance.png" + assert f"![Permutation Importance on f1]({temp_path_f1}" in model_card.render() + + +def test_duplicate_permutation_importances( + iris_estimator, iris_data, model_card, destination_path +): + X, y = iris_data + result = permutation_importance( + iris_estimator, X, y, n_repeats=10, random_state=42, n_jobs=2 + ) + plot_path = os.path.join(destination_path, "importance.png") + model_card.add_permutation_importances(result, X.columns, plot_file=plot_path) + with pytest.raises( + ValueError, + match=( + "already exists. Set `overwrite` to `True` or pass a" + " different filename for the plot." + ), + ): + model_card.add_permutation_importances( + result, + X.columns, + plot_file=plot_path, + plot_name="Permutation Importance on f1", + ) + + +def test_duplicate_permutation_importances_overwrite( + iris_estimator, iris_data, model_card, destination_path +): + X, y = iris_data + result = permutation_importance( + iris_estimator, X, y, n_repeats=10, random_state=42, n_jobs=2 + ) + plot_path = os.path.join(destination_path, "importance.png") + model_card.add_permutation_importances(result, X.columns, plot_file=plot_path) + + model_card.add_permutation_importances( + result, + X.columns, + plot_file=plot_path, + plot_name="Permutation Importance on f1", + overwrite=True, + ) + assert f"![Permutation Importance on f1]({plot_path}" in model_card.render() + + class TestAddGetStartedCode: """Tests for getting started code""" @@ -856,6 +947,8 @@ def test_delete_empty_key_subsection_raises(self, model_card): class TestAddPlot: def test_add_plot(self, destination_path, model_card): + import matplotlib.pyplot as plt + plt.plot([4, 5, 6, 7]) plt.savefig(Path(destination_path) / "fig1.png") model_card = model_card.add_plot(fig1="fig1.png") @@ -863,6 +956,8 @@ def test_add_plot(self, destination_path, model_card): assert plot_content == "![fig1](fig1.png)" def test_add_plot_to_existing_section(self, destination_path, model_card): + import matplotlib.pyplot as plt + plt.plot([4, 5, 6, 7]) plt.savefig(Path(destination_path) / "fig1.png") model_card = model_card.add_plot(**{"Model description/Figure 1": "fig1.png"}) diff --git a/skops/card/tests/test_parser.py b/skops/card/tests/test_parser.py index b74486fe..d140149b 100644 --- a/skops/card/tests/test_parser.py +++ b/skops/card/tests/test_parser.py @@ -90,7 +90,6 @@ def test_example_model_cards(tmp_path, file_name): path = Path(os.getcwd()) / "skops" / "card" / "tests" / "examples" file0 = path / file_name diff = (path / file_name).with_suffix(".md.diff") - parsed_card = parse_modelcard(file0) file1 = tmp_path / "readme-parsed.md" parsed_card.save(file1) diff --git a/skops/conftest.py b/skops/conftest.py index 4dcaed83..9ee0a4db 100644 --- a/skops/conftest.py +++ b/skops/conftest.py @@ -1,3 +1,4 @@ +import builtins from unittest.mock import patch import pytest @@ -7,7 +8,8 @@ def pandas_not_installed(): # patch import so that it raises an ImportError when trying to import # pandas. This works because pandas is only imported lazily. - orig_import = __import__ + + orig_import = builtins.__import__ def mock_import(name, *args, **kwargs): if name == "pandas": @@ -16,3 +18,28 @@ def mock_import(name, *args, **kwargs): with patch("builtins.__import__", side_effect=mock_import): yield + + +@pytest.fixture +def matplotlib_not_installed(): + # patch import so that it raises an ImportError when trying to import + # matplotlib. This works because matplotlib is only imported lazily. + + # ugly way of removing matplotlib from cached imports + import sys + + for key in list(sys.modules.keys()): + if key.startswith("matplotlib"): + del sys.modules[key] + + orig_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "matplotlib": + raise ImportError + return orig_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + yield + + import matplotlib # noqa diff --git a/skops/hub_utils/_hf_hub.py b/skops/hub_utils/_hf_hub.py index cd302c7b..39139a15 100644 --- a/skops/hub_utils/_hf_hub.py +++ b/skops/hub_utils/_hf_hub.py @@ -206,7 +206,7 @@ def _create_config( "text-regression", ], data, - model_format: Literal[ # type: ignore + model_format: Literal[ "skops", "pickle", "auto", @@ -337,7 +337,7 @@ def init( "text-regression", ], data, - model_format: Literal[ # type: ignore + model_format: Literal[ "skops", "pickle", "auto", diff --git a/skops/utils/importutils.py b/skops/utils/importutils.py new file mode 100644 index 00000000..ac81a203 --- /dev/null +++ b/skops/utils/importutils.py @@ -0,0 +1,29 @@ +from importlib import import_module + + +def import_or_raise(module, feature_name): + """Raise error if a given library is not present in the environment. + + Parameters + ---------- + module: str + Name of the module. + + feature_name: str + Name of the feature module is required for. + + Raises + ------ + ModuleNotFoundError + Is raised if a given module is not present in the environment + """ + try: + module = import_module(module) + except ImportError as e: + package = module.split(".")[0] + raise ModuleNotFoundError( + f"{feature_name.capitalize()} requires {package} to be installed. In order" + f" to use {feature_name}, you need to install the package in your current" + " python environment." + ) from e + return module diff --git a/skops/utils/tests/test_importutils.py b/skops/utils/tests/test_importutils.py new file mode 100644 index 00000000..c7424d09 --- /dev/null +++ b/skops/utils/tests/test_importutils.py @@ -0,0 +1,16 @@ +import pytest + +from skops.utils.importutils import import_or_raise + + +@pytest.mark.usefixtures("matplotlib_not_installed") +def test_import_or_raise(): + with pytest.raises( + ModuleNotFoundError, + match=( + "Permutation importance requires matplotlib to be installed. In order" + " to use permutation importance, you need to install the package in" + " your current python environment." + ), + ): + import_or_raise("matplotlib", "permutation importance")