Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,18 @@ class Card:
``Path``/``str`` of the model or the actual model instance that will be
documented. If a ``Path`` or ``str`` is provided, model will be loaded.

model_diagram: bool, default=True
Set to True if model diagram should be plotted in the card.
model_diagram: bool or "auto" or str, default="auto"
If using the skops template, setting this to ``True`` or ``"auto"`` will
add the model diagram, as generated by scikit-learn, to the default
section, i.e. "Model description/Training Procedure/Model Plot". Passing
a string to ``model_diagram`` will instead use that string as the
section name for the diagram. Set to ``False`` to not include the model
diagram.

If using a non-skops template, passing ``"auto"`` won't add the model
diagram because there is no pre-defined section to put it in. The model
diagram can, however, always be added later using
:meth:`Card.add_model_plot`.

metadata: ModelCardData, optional
:class:`huggingface_hub.ModelCardData` object. The contents of this
Expand Down Expand Up @@ -481,27 +491,36 @@ class Card:
def __init__(
self,
model,
model_diagram: bool = True,
model_diagram: bool | Literal["auto"] | str = "auto",
metadata: ModelCardData | None = None,
template: Literal["skops"] | dict[str, str] | None = "skops",
trusted: bool = False,
) -> None:
self.model = model
self.model_diagram = model_diagram
self.metadata = metadata or ModelCardData()
self.template = template
self.trusted = trusted

self._data: dict[str, Section] = {}
self._metrics: dict[str, str | float | int] = {}

self._populate_template()
self._populate_template(model_diagram=model_diagram)

def _populate_template(self):
"""If initialized with a template, use it to populate the card."""
if not self.template:
return
def _populate_template(self, model_diagram: bool | Literal["auto"] | str):
"""If initialized with a template, use it to populate the card.

Parameters
----------
model_diagram: bool or "auto" or str
If using the default template, ``"auto"`` and ``True`` will add the
diagram in its default section. If using a custom template,
``"auto"`` will not add the diagram, and passing ``True`` will
result in an error. For either, passing ``False`` will result in the
model diagram being omitted, and passing a string (other than
``"auto"``) will put the model diagram into a section corresponding
to that string.

"""
if isinstance(self.template, str) and (self.template not in VALID_TEMPLATES):
valid_templates = ", ".join(f"'{val}'" for val in sorted(VALID_TEMPLATES))
msg = (
Expand All @@ -510,15 +529,29 @@ def _populate_template(self):
)
raise ValueError(msg)

# default template
if self.template == Templates.skops.value:
self.add(**SKOPS_TEMPLATE)
# for the skops template, automatically add some default sections
self.add_model_plot()
self.add_hyperparams()
self.add_get_started_code()
elif isinstance(self.template, Mapping):

if (model_diagram is True) or (model_diagram == "auto"):
self.add_model_plot()
elif isinstance(model_diagram, str):
self.add_model_plot(section=model_diagram)
return

# non-default template
if isinstance(self.template, Mapping):
self.add(**self.template)

if isinstance(model_diagram, str) and (model_diagram != "auto"):
self.add_model_plot(section=model_diagram)
elif model_diagram is True:
# no default section exists for custom templates, so this raises a ValueError
self.add_model_plot()

def get_model(self) -> Any:
"""Returns sklearn estimator object.

Expand Down Expand Up @@ -789,9 +822,6 @@ def add_model_plot(
self : object
Card object.
"""
if not self.model_diagram:
return self

if section is None:
if self.template == Templates.skops.value:
section = "Model description/Training Procedure/Model Plot"
Expand Down
61 changes: 61 additions & 0 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,26 @@ def test_model_diagram_false(self):
).content
assert result == "The model plot is below."

def test_model_diagram_str(self):
    """A string ``model_diagram`` is used as the section name for the plot."""
    custom_section = "Here is the model diagram"
    card = Card(fit_model(), model_diagram=custom_section)

    # the default section must keep only its placeholder text
    default_section = "Model description/Training Procedure/Model Plot"
    placeholder = card.select(default_section).format()
    assert placeholder == "The model plot is below."

    # the rendered diagram is expected in the custom section instead
    rendered = card.select(custom_section).format()
    assert rendered.startswith("The model plot is below.\n\n<style>#sk-")
    assert "<style>" in rendered
    assert rendered.endswith(
        "<pre>LinearRegression()</pre></div></div></div></div></div>"
    )

def test_other_section(self, model_card):
model_card.add_model_plot(section="Other section")
result = model_card.select("Other section").content
Expand Down Expand Up @@ -204,6 +224,47 @@ def test_custom_template_no_section_raises(self, template):
with pytest.raises(ValueError, match=msg):
model_card.add_model_plot()

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_init_str_works(self, template):
    """With a custom template, a string ``model_diagram`` names the section."""
    diagram_section = "Here is the model diagram"
    card = Card(fit_model(), template=template, model_diagram=diagram_section)

    # only check the start and end of the rendered diagram
    rendered = card.select(diagram_section).format()
    assert rendered.startswith("<style>#sk-")
    assert "<style>" in rendered
    assert rendered.endswith(
        "<pre>LinearRegression()</pre></div></div></div></div></div>"
    )

def test_default_template_and_model_diagram_true(self):
    """``model_diagram=True`` behaves like ``"auto"`` with the default template.

    The card is built inside the test because it needs a non-default
    ``model_diagram`` argument, so no card fixture is requested.
    """
    model = fit_model()
    model_card = Card(model, model_diagram=True)
    result = model_card.select(
        "Model description/Training Procedure/Model Plot"
    ).content
    # don't compare whole text, as it's quite long and non-deterministic
    assert result.startswith("The model plot is below.\n\n<style>#sk-")
    assert "<style>" in result
    assert result.endswith(
        "<pre>LinearRegression()</pre></div></div></div></div></div>"
    )

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_and_model_diagram_true(self, template):
    """``model_diagram=True`` with a custom template fails at card creation.

    The card is constructed inside the test (construction itself is what is
    being checked), so no card fixture is requested.
    """
    # unlike with the default template, setting model_diagram=True while
    # using a custom template is expected to raise an error during
    # initialization of the model card
    model = fit_model()
    msg = (
        "You are trying to add a model plot but you're using a custom template, "
        "please pass the 'section' argument to determine where to put the content"
    )
    with pytest.raises(ValueError, match=msg):
        Card(model, template=template, model_diagram=True)

def test_add_twice(self, model_card):
# it's possible to add the section twice, even if it doesn't make a lot
# of sense
Expand Down