diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index d5e0e55c..40254a2e 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -457,7 +457,7 @@ def __init__( trusted: bool = False, ) -> None: self.model = model - self.model_diagram = model_diagram + self._model_diagram = model_diagram self.metadata = metadata or ModelCardData() self.template = template self.trusted = trusted @@ -704,12 +704,46 @@ def _add_single(self, key: str, val: Formattable | str) -> Section: if leaf_node_name in section: # entry exists, only overwrite content section[leaf_node_name].content = val + # if node already existed but was invisible, make it visible + section[leaf_node_name].visible = True else: # entry does not exist, create a new one section[leaf_node_name] = Section(title=leaf_node_name, content=val) return section[leaf_node_name] + @property + def model_diagram(self) -> bool: + return self._model_diagram + + @model_diagram.setter + def model_diagram(self, value: bool) -> None: + if self._model_diagram is value: + # nothing to change, early return + return + + self._model_diagram = value + + # If we use the skops template, we know what section to add or remove + # when model_diagram changes values. If not, we don't know and thus need + # to skip this step. + if self.template != Templates.skops.value: + msg = ( + "You are trying to deactivate the model diagram, which does not work " + "when using a custom template. Instead, delete the diagram directly by " + "calling 'model_card.delete()" + ) + raise ValueError(msg) + + section_name = "Model description/Training Procedure/Model Plot" + if not value: # don't show model diagram + section = self.select(section_name) + section.visible = False + else: # do show model diagram + self._add_model_plot( + self.get_model(), section=section_name, description=None + ) + def add_model_plot( self, section: str | None = None, diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 9ad94277..e640ddd4 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -175,6 +175,20 @@ def test_model_diagram_false(self): ).content assert result == "The model plot is below." + def test_model_diagram_false_add_manually(self): + # Here we set model_diagram=False but then later explicitly call + # model_card.add_model_plot. The current behavior is not to show the + # model plot in this case. It could be discussed if that's the expected + # outcome or not. + model = fit_model() + model_card = Card(model, model_diagram=False) + model_card.add_model_plot() + + result = model_card.select( + "Model description/Training Procedure/Model Plot" + ).content + assert result == "The model plot is below." + def test_other_section(self, model_card): model_card.add_model_plot(section="Other section") result = model_card.select("Other section").content @@ -191,6 +205,34 @@ def test_other_description(self, model_card): ).content assert result.startswith("Awesome diagram below\n\n