Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def __init__(
trusted: bool = False,
) -> None:
self.model = model
self.model_diagram = model_diagram
self._model_diagram = model_diagram
self.metadata = metadata or ModelCardData()
self.template = template
self.trusted = trusted
Expand Down Expand Up @@ -704,12 +704,46 @@ def _add_single(self, key: str, val: Formattable | str) -> Section:
if leaf_node_name in section:
# entry exists, only overwrite content
section[leaf_node_name].content = val
# if node already existed but was invisible, make it visible
section[leaf_node_name].visible = True
else:
# entry does not exist, create a new one
section[leaf_node_name] = Section(title=leaf_node_name, content=val)

return section[leaf_node_name]

@property
def model_diagram(self) -> bool:
return self._model_diagram

@model_diagram.setter
def model_diagram(self, value: bool) -> None:
if self._model_diagram is value:
# nothing to change, early return
return

self._model_diagram = value

# If we use the skops template, we know what section to add or remove
# when model_diagram changes values. If not, we don't know and thus need
# to skip this step.
if self.template != Templates.skops.value:
msg = (
"You are trying to deactivate the model diagram, which does not work "
"when using a custom template. Instead, delete the diagram directly by "
"calling 'model_card.delete(<name-of-model-diagram-section>)"
)
raise ValueError(msg)

section_name = "Model description/Training Procedure/Model Plot"
if not value: # don't show model diagram
section = self.select(section_name)
section.visible = False
else: # do show model diagram
self._add_model_plot(
self.get_model(), section=section_name, description=None
)

def add_model_plot(
self,
section: str | None = None,
Expand Down
103 changes: 103 additions & 0 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ def test_model_diagram_false(self):
).content
assert result == "The model plot is below."

def test_model_diagram_false_add_manually(self):
# Here we set model_diagram=False but then later explicitly call
# model_card.add_model_plot. The current behavior is not to show the
# model plot in this case. It could be discussed if that's the expected
# outcome or not.
model = fit_model()
model_card = Card(model, model_diagram=False)
model_card.add_model_plot()

result = model_card.select(
"Model description/Training Procedure/Model Plot"
).content
assert result == "The model plot is below."

def test_other_section(self, model_card):
model_card.add_model_plot(section="Other section")
result = model_card.select("Other section").content
Expand All @@ -191,6 +205,34 @@ def test_other_description(self, model_card):
).content
assert result.startswith("Awesome diagram below\n\n<style>#sk-")

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_manually_adding_works(self, model_card, template):
model = fit_model()
model_card = Card(model, template=template)
model_card.add_model_plot(section="My model diagram")

result = model_card.select("My model diagram").content
assert result.startswith("<style>#sk-")
assert "<style>" in result
assert result.endswith(
"<pre>LinearRegression()</pre></div></div></div></div></div>"
)

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_description_manually_adding_works(
self, model_card, template
):
model = fit_model()
model_card = Card(model, template=template)
model_card.add_model_plot(section="My model diagram", description="Tada")

result = model_card.select("My model diagram").content
assert result.startswith("Tada\n\n<style>#sk-")
assert "<style>" in result
assert result.endswith(
"<pre>LinearRegression()</pre></div></div></div></div></div>"
)

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_no_section_raises(self, template):
model = fit_model()
Expand All @@ -215,6 +257,67 @@ def test_add_twice(self, model_card):
# thus compare everything but the numbers
assert re.split(r"\d+", text1) == re.split(r"\d+", text2)

def test_setting_model_diagram_false_removes_it(self, model_card):
model_card.model_diagram = False
rendered = model_card.render()

# don't compare whole text, as it's quite long and non-deterministic
assert "The model plot is below.\n\n<style>#sk-" not in rendered
assert "<style>" not in rendered
assert (
"<pre>LinearRegression()</pre></div></div></div></div></div>"
not in rendered
)

@pytest.mark.xfail(strict=True)
def test_setting_model_diagram_false_other_section(self, model_card):
# when the model diagram is in another section, this does not work
# currently, thus we xfail this test
model_card.add_model_plot(section="Other section")
model_card.model_diagram = False
rendered = model_card.render()

assert "The model plot is below.\n\n<style>#sk-" not in rendered
assert "<style>" not in rendered
assert (
"<pre>LinearRegression()</pre></div></div></div></div></div>"
not in rendered
)

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_setting_model_diagram_false_custom_template(self, model_card, template):
model = fit_model()
model_card = Card(model, template=template, model_diagram=True)
model_card.add_model_plot("A beautiful estimator")

match = "You are trying to deactivate the model diagram"
with pytest.raises(ValueError, match=match):
model_card.model_diagram = False

def test_setting_model_diagram_false_twice_no_error(self, model_card):
# check that this does not raise
model_card.model_diagram = False
model_card.model_diagram = False
rendered = model_card.render()

# don't compare whole text, as it's quite long and non-deterministic
assert "The model plot is below.\n\n<style>#sk-" not in rendered
assert "<style>" not in rendered
assert (
"<pre>LinearRegression()</pre></div></div></div></div></div>"
not in rendered
)

def test_setting_model_diagram_false_then_true(self, model_card):
model_card.model_diagram = False
model_card.model_diagram = True
rendered = model_card.render()

# don't compare whole text, as it's quite long and non-deterministic
assert "<style>#sk-" in rendered
assert "<style>" in rendered
assert "<pre>LinearRegression()</pre></div></div></div></div></div>" in rendered


def _strip_multiple_chars(text, char):
# utility function needed to compare tables across systems
Expand Down