diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index c978e670..2d0d64ee 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -5,6 +5,7 @@ import re import shutil import tempfile +import joblib from dataclasses import dataclass from pathlib import Path from reprlib import Repr @@ -15,6 +16,7 @@ from tabulate import tabulate # type: ignore import skops +from skops.io import load # Repr attributes can be used to control the behavior of repr aRepr = Repr() @@ -25,7 +27,10 @@ def wrap_as_details(text: str, folded: bool) -> str: if not folded: return text - return f"
\n Click to expand \n\n{text}\n\n
" + return ( + "
\n Click to expand" + f" \n\n{text}\n\n
" + ) def _clean_table(table: str) -> str: @@ -92,7 +97,9 @@ def format(self) -> str: headers = self.table.keys() table = _clean_table( - tabulate(self.table, tablefmt="github", headers=headers, showindex=False) + tabulate( + self.table, tablefmt="github", headers=headers, showindex=False + ) ) return wrap_as_details(table, folded=self.folded) @@ -150,7 +157,9 @@ def metadata_from_config(config_path: Union[str, Path]) -> CardData: task = config.get("sklearn", {}).get("task", None) if task: card_data.tags += [task] - card_data.model_file = config.get("sklearn", {}).get("model", {}).get("file") + card_data.model_file = ( + config.get("sklearn", {}).get("model", {}).get("file") + ) example_input = config.get("sklearn", {}).get("example_input", None) # Documentation on what the widget expects: # https://huggingface.co/docs/hub/models-widgets-examples @@ -172,7 +181,7 @@ class Card: Parameters ---------- - model: estimator object + model: pathlib.path, str, or sklearn estimator object Model that will be documented. model_diagram: bool, default=True @@ -254,7 +263,7 @@ def __init__( model_diagram: bool = True, metadata: Optional[CardData] = None, ) -> None: - self.model = model + self._model = model self.model_diagram = model_diagram self._eval_results = {} # type: ignore self._template_sections: dict[str, str] = {} @@ -302,11 +311,15 @@ def add_plot(self, folded=False, **kwargs: str) -> "Card": Card object. """ for plot_name, plot_path in kwargs.items(): - section = PlotSection(alt_text=plot_name, path=plot_path, folded=folded) + section = PlotSection( + alt_text=plot_name, path=plot_path, folded=folded + ) self._extra_sections.append((plot_name, section)) return self - def add_table(self, folded: bool = False, **kwargs: dict["str", list[Any]]) -> Card: + def add_table( + self, folded: bool = False, **kwargs: dict["str", list[Any]] + ) -> Card: """Add a table to the model card. Add a table to the model card. This can be especially useful when you @@ -373,6 +386,56 @@ def add_metrics(self, **kwargs: str) -> "Card": self._eval_results[metric] = value return self + @property + def model(self): + model = self._load_model(self._model) + if model is not self._model: + self._model = model + return model + + @model.setter + def model(self, model): + self._model = model + + @model.deleter + def model(self): + del self._model + + def _load_model(self, model: Any) -> Any: + """Loads the model if provided a file path, if already a model instance, + return it unmodified. + + Parameters + ---------- + model : pathlib.path, str, or sklearn estimator + Path/str or the actual model instance. If a Path or str, loads the model on first call. + + Returns + ------- + model : object + Model instance. + + """ + if not isinstance(model, (Path, str)): + return model + + model_path = Path(model) + if not model_path.exists(): + raise ValueError("Model file does not exist") + + if model_path.suffix in (".pkl", ".pickle"): + model = joblib.load(model_path) + elif model_path.suffix == ".skops": + model = load(model_path) + else: + msg = ( + f"Cannot interpret model suffix {model_path.suffix}, should be" + " '.pkl', '.pickle' or '.skops'" + ) + raise ValueError(msg) + + return model + def _generate_card(self) -> ModelCard: """Generate the ModelCard object @@ -403,18 +466,23 @@ def _generate_card(self) -> ModelCard: ) else: template_sections["get_started_code"] = ( - "import joblib\nimport json\nimport pandas as pd\nclf =" - f' joblib.load({model_file})\nwith open("config.json") as' + "import joblib\nimport json\nimport pandas as" + " pd\nclf =" + f" joblib.load({model_file})\nwith" + ' open("config.json") as' " f:\n " " config =" " json.load(f)\n" 'clf.predict(pd.DataFrame.from_dict(config["sklearn"]["example_input"]))' ) if self.model_diagram is True: - model_plot_div = re.sub(r"\n\s+", "", str(estimator_html_repr(self.model))) + model_plot_div = re.sub( + r"\n\s+", "", str(estimator_html_repr(self.model)) + ) if model_plot_div.count("sk-top-container") == 1: model_plot_div = model_plot_div.replace( - "sk-top-container", 'sk-top-container" style="overflow: auto;' + "sk-top-container", + 'sk-top-container" style="overflow: auto;', ) model_plot: str | None = model_plot_div else: @@ -439,7 +507,9 @@ def _generate_card(self) -> ModelCard: f"{tmpdirname}/temporary_template.md", ) # create a temporary template with the additional plots - template_sections["template_path"] = f"{tmpdirname}/temporary_template.md" + template_sections[ + "template_path" + ] = f"{tmpdirname}/temporary_template.md" # add extra sections at the end of the template with open(template_sections["template_path"], "a") as template: if self._extra_sections: @@ -520,13 +590,17 @@ def __repr__(self) -> str: model = getattr(self, "model", None) if model: model_str = self._strip_blank(repr(model)) - model_repr = aRepr.repr(f" model={model_str},").strip('"').strip("'") + model_repr = ( + aRepr.repr(f" model={model_str},").strip('"').strip("'") + ) else: model_repr = None # metadata metadata_reprs = [] - for key, val in self.metadata.to_dict().items() if self.metadata else {}: + for key, val in ( + self.metadata.to_dict().items() if self.metadata else {} + ): if key == "widget": metadata_reprs.append(" metadata.widget={...},") continue @@ -540,14 +614,18 @@ def __repr__(self) -> str: template_reprs = [] for key, val in self._template_sections.items(): val = self._strip_blank(repr(val)) - template_reprs.append(aRepr.repr(f" {key}={val},").strip('"').strip("'")) + template_reprs.append( + aRepr.repr(f" {key}={val},").strip('"').strip("'") + ) template_repr = "\n".join(template_reprs) # figures figure_reprs = [] for key, val in self._extra_sections: val = self._strip_blank(repr(val)) - figure_reprs.append(aRepr.repr(f" {key}={val},").strip('"').strip("'")) + figure_reprs.append( + aRepr.repr(f" {key}={val},").strip('"').strip("'") + ) figure_repr = "\n".join(figure_reprs) complete_repr = "Card(\n" diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index e2ed4596..a92ebe47 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -33,6 +33,19 @@ def model_card(model_diagram=True): yield card +@pytest.fixture +def model_card_from_path(suffix, model_diagram=True): + model = fit_model() + save_file = tempfile.mkstemp(suffix=suffix, prefix="skops-test")[1] + if suffix in (".pkl", ".pickle"): + with open(save_file, "wb") as f: + pickle.dump(model, f) + elif suffix == ".skops": + dump(model, save_file) + card = Card(save_file, model_diagram) + yield card + + @pytest.fixture def iris_data(): X, y = load_iris(return_X_y=True, as_frame=True) @@ -82,6 +95,24 @@ def _create_model_card_from_saved_model( return card +def _create_model_card_from_model_path( + destination_path, + iris_data, + save_file, +): + X, y = iris_data + hub_utils.init( + model=save_file, + requirements=[f"scikit-learn=={sklearn.__version__}"], + dst=destination_path, + task="tabular-classification", + data=X, + ) + card = Card(save_file, metadata=metadata_from_config(destination_path)) + card.save(Path(destination_path) / "README.md") + return card + + @pytest.fixture def skops_model_card_metadata_from_config( destination_path, iris_estimator, iris_skops_file, iris_data @@ -100,6 +131,22 @@ def pkl_model_card_metadata_from_config( ) +@pytest.fixture +def skops_model_card_from_path_metadata_from_config( + destination_path, iris_skops_file, iris_data +): + yield _create_model_card_from_model_path( + destination_path, iris_data, iris_skops_file + ) + + +@pytest.fixture +def pkl_model_card_from_path_metadata_from_config( + destination_path, iris_pkl_file, iris_data +): + yield _create_model_card_from_model_path(destination_path, iris_data, iris_pkl_file) + + @pytest.fixture def destination_path(): with tempfile.TemporaryDirectory(prefix="skops-test") as dir_path: @@ -111,11 +158,23 @@ def test_save_model_card(destination_path, model_card): assert (Path(destination_path) / "README.md").exists() +@pytest.mark.parametrize("suffix", [".pkl", ".pickle", ".skops"]) +def test_save_model_card_from_path(destination_path, model_card_from_path): + model_card_from_path.save(Path(destination_path) / "README.md") + assert (Path(destination_path) / "README.md").exists() + + def test_hyperparameter_table(destination_path, model_card): model_card = model_card.render() assert "fit_intercept" in model_card +@pytest.mark.parametrize("suffix", [".pkl", ".pickle", ".skops"]) +def test_hyperparameter_table_from_path(model_card_from_path): + model_card_from_path = model_card_from_path.render() + assert "fit_intercept" in model_card_from_path + + def _strip_multiple_chars(text, char): # _strip_multiple_chars("hi there") == "hi there" # _strip_multiple_chars("|---|--|", "-") == "|-|-|" @@ -144,17 +203,44 @@ def test_plot_model(destination_path, model_card): assert "