Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 94 additions & 16 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import shutil
import tempfile
import joblib
from dataclasses import dataclass
from pathlib import Path
from reprlib import Repr
Expand All @@ -15,6 +16,7 @@
from tabulate import tabulate # type: ignore

import skops
from skops.io import load

# Repr attributes can be used to control the behavior of repr
aRepr = Repr()
Expand All @@ -25,7 +27,10 @@
def wrap_as_details(text: str, folded: bool) -> str:
if not folded:
return text
return f"<details>\n<summary> Click to expand </summary>\n\n{text}\n\n</details>"
return (
"<details>\n<summary> Click to expand"
f" </summary>\n\n{text}\n\n</details>"
)


def _clean_table(table: str) -> str:
Expand Down Expand Up @@ -92,7 +97,9 @@ def format(self) -> str:
headers = self.table.keys()

table = _clean_table(
tabulate(self.table, tablefmt="github", headers=headers, showindex=False)
tabulate(
self.table, tablefmt="github", headers=headers, showindex=False
)
)
return wrap_as_details(table, folded=self.folded)

Expand Down Expand Up @@ -150,7 +157,9 @@ def metadata_from_config(config_path: Union[str, Path]) -> CardData:
task = config.get("sklearn", {}).get("task", None)
if task:
card_data.tags += [task]
card_data.model_file = config.get("sklearn", {}).get("model", {}).get("file")
card_data.model_file = (
config.get("sklearn", {}).get("model", {}).get("file")
)
example_input = config.get("sklearn", {}).get("example_input", None)
# Documentation on what the widget expects:
# https://huggingface.co/docs/hub/models-widgets-examples
Expand All @@ -172,7 +181,7 @@ class Card:

Parameters
----------
model: estimator object
model: pathlib.path, str, or sklearn estimator object
Model that will be documented.

model_diagram: bool, default=True
Expand Down Expand Up @@ -254,7 +263,7 @@ def __init__(
model_diagram: bool = True,
metadata: Optional[CardData] = None,
) -> None:
self.model = model
self._model = model
self.model_diagram = model_diagram
self._eval_results = {} # type: ignore
self._template_sections: dict[str, str] = {}
Expand Down Expand Up @@ -302,11 +311,15 @@ def add_plot(self, folded=False, **kwargs: str) -> "Card":
Card object.
"""
for plot_name, plot_path in kwargs.items():
section = PlotSection(alt_text=plot_name, path=plot_path, folded=folded)
section = PlotSection(
alt_text=plot_name, path=plot_path, folded=folded
)
self._extra_sections.append((plot_name, section))
return self

def add_table(self, folded: bool = False, **kwargs: dict["str", list[Any]]) -> Card:
def add_table(
self, folded: bool = False, **kwargs: dict["str", list[Any]]
) -> Card:
"""Add a table to the model card.

Add a table to the model card. This can be especially useful when you
Expand Down Expand Up @@ -373,6 +386,56 @@ def add_metrics(self, **kwargs: str) -> "Card":
self._eval_results[metric] = value
return self

@property
def model(self):
model = self._load_model(self._model)
if model is not self._model:
self._model = model
return model

@model.setter
def model(self, model):
self._model = model

@model.deleter
def model(self):
del self._model

def _load_model(self, model: Any) -> Any:
"""Loads the model if provided a file path, if already a model instance,
return it unmodified.

Parameters
----------
model : pathlib.path, str, or sklearn estimator
Path/str or the actual model instance. If a Path or str, loads the model on first call.

Returns
-------
model : object
Model instance.

"""
if not isinstance(model, (Path, str)):
return model

model_path = Path(model)
if not model_path.exists():
raise ValueError("Model file does not exist")

if model_path.suffix in (".pkl", ".pickle"):
model = joblib.load(model_path)
elif model_path.suffix == ".skops":
model = load(model_path)
else:
msg = (
f"Cannot interpret model suffix {model_path.suffix}, should be"
" '.pkl', '.pickle' or '.skops'"
)
raise ValueError(msg)

return model

def _generate_card(self) -> ModelCard:
"""Generate the ModelCard object

Expand Down Expand Up @@ -403,18 +466,23 @@ def _generate_card(self) -> ModelCard:
)
else:
template_sections["get_started_code"] = (
"import joblib\nimport json\nimport pandas as pd\nclf ="
f' joblib.load({model_file})\nwith open("config.json") as'
"import joblib\nimport json\nimport pandas as"
" pd\nclf ="
f" joblib.load({model_file})\nwith"
' open("config.json") as'
" f:\n "
" config ="
" json.load(f)\n"
'clf.predict(pd.DataFrame.from_dict(config["sklearn"]["example_input"]))'
)
if self.model_diagram is True:
model_plot_div = re.sub(r"\n\s+", "", str(estimator_html_repr(self.model)))
model_plot_div = re.sub(
r"\n\s+", "", str(estimator_html_repr(self.model))
)
if model_plot_div.count("sk-top-container") == 1:
model_plot_div = model_plot_div.replace(
"sk-top-container", 'sk-top-container" style="overflow: auto;'
"sk-top-container",
'sk-top-container" style="overflow: auto;',
)
model_plot: str | None = model_plot_div
else:
Expand All @@ -439,7 +507,9 @@ def _generate_card(self) -> ModelCard:
f"{tmpdirname}/temporary_template.md",
)
# create a temporary template with the additional plots
template_sections["template_path"] = f"{tmpdirname}/temporary_template.md"
template_sections[
"template_path"
] = f"{tmpdirname}/temporary_template.md"
# add extra sections at the end of the template
with open(template_sections["template_path"], "a") as template:
if self._extra_sections:
Expand Down Expand Up @@ -520,13 +590,17 @@ def __repr__(self) -> str:
model = getattr(self, "model", None)
if model:
model_str = self._strip_blank(repr(model))
model_repr = aRepr.repr(f" model={model_str},").strip('"').strip("'")
model_repr = (
aRepr.repr(f" model={model_str},").strip('"').strip("'")
)
else:
model_repr = None

# metadata
metadata_reprs = []
for key, val in self.metadata.to_dict().items() if self.metadata else {}:
for key, val in (
self.metadata.to_dict().items() if self.metadata else {}
):
if key == "widget":
metadata_reprs.append(" metadata.widget={...},")
continue
Expand All @@ -540,14 +614,18 @@ def __repr__(self) -> str:
template_reprs = []
for key, val in self._template_sections.items():
val = self._strip_blank(repr(val))
template_reprs.append(aRepr.repr(f" {key}={val},").strip('"').strip("'"))
template_reprs.append(
aRepr.repr(f" {key}={val},").strip('"').strip("'")
)
template_repr = "\n".join(template_reprs)

# figures
figure_reprs = []
for key, val in self._extra_sections:
val = self._strip_blank(repr(val))
figure_reprs.append(aRepr.repr(f" {key}={val},").strip('"').strip("'"))
figure_reprs.append(
aRepr.repr(f" {key}={val},").strip('"').strip("'")
)
figure_repr = "\n".join(figure_reprs)

complete_repr = "Card(\n"
Expand Down
Loading