diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index 15a8fe04..5d9942f5 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -380,8 +380,18 @@ class Card: ``Path``/``str`` of the model or the actual model instance that will be documented. If a ``Path`` or ``str`` is provided, model will be loaded. - model_diagram: bool, default=True - Set to True if model diagram should be plotted in the card. + model_diagram: bool or "auto" or str, default="auto" + If using the skops template, setting this to ``True`` or ``"auto"`` will + add the model diagram, as generated by sckit-learn, to the default + section, i.e "Model description/Training Procedure/Model Plot". Passing + a string to ``model_diagram`` will instead use that string as the + section name for the diagram. Set to ``False`` to not include the model + diagram. + + If using a non-skops template, passing ``"auto"`` won't add the model + diagram because there is no pre-defined section to put it. The model + diagram can, however, always be added later using + :meth:`Card.add_model_plot`. metadata: ModelCardData, optional :class:`huggingface_hub.ModelCardData` object. The contents of this @@ -481,13 +491,12 @@ class Card: def __init__( self, model, - model_diagram: bool = True, + model_diagram: bool | Literal["auto"] | str = "auto", metadata: ModelCardData | None = None, template: Literal["skops"] | dict[str, str] | None = "skops", trusted: bool = False, ) -> None: self.model = model - self.model_diagram = model_diagram self.metadata = metadata or ModelCardData() self.template = template self.trusted = trusted @@ -495,13 +504,23 @@ def __init__( self._data: dict[str, Section] = {} self._metrics: dict[str, str | float | int] = {} - self._populate_template() + self._populate_template(model_diagram=model_diagram) - def _populate_template(self): - """If initialized with a template, use it to populate the card.""" - if not self.template: - return + def _populate_template(self, model_diagram: bool | Literal["auto"] | str): + """If initialized with a template, use it to populate the card. + + Parameters + ---------- + model_diagram: bool or "auto" or str + If using the default template, ``"auto"`` and ``True`` will add the + diagram in its default section. If using a custom template, + ``"auto"`` will not add the diagram, and passing ``True`` will + result in an error. For either, passing ``False`` will result in the + model diagram being omitted, and passing a string (other than + ``"auto"``) will put the model diagram into a section corresponding + to that string. + """ if isinstance(self.template, str) and (self.template not in VALID_TEMPLATES): valid_templates = ", ".join(f"'{val}'" for val in sorted(VALID_TEMPLATES)) msg = ( @@ -510,15 +529,29 @@ def _populate_template(self): ) raise ValueError(msg) + # default template if self.template == Templates.skops.value: self.add(**SKOPS_TEMPLATE) # for the skops template, automatically add some default sections - self.add_model_plot() self.add_hyperparams() self.add_get_started_code() - elif isinstance(self.template, Mapping): + + if (model_diagram is True) or (model_diagram == "auto"): + self.add_model_plot() + elif isinstance(model_diagram, str): + self.add_model_plot(section=model_diagram) + return + + # non-default template + if isinstance(self.template, Mapping): self.add(**self.template) + if isinstance(model_diagram, str) and (model_diagram != "auto"): + self.add_model_plot(section=model_diagram) + elif model_diagram is True: + # will trigger an error + self.add_model_plot() + def get_model(self) -> Any: """Returns sklearn estimator object. @@ -789,9 +822,6 @@ def add_model_plot( self : object Card object. """ - if not self.model_diagram: - return self - if section is None: if self.template == Templates.skops.value: section = "Model description/Training Procedure/Model Plot" diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 0a97a303..f9d8e0d8 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -177,6 +177,26 @@ def test_model_diagram_false(self): ).content assert result == "The model plot is below." + def test_model_diagram_str(self): + # if passing a str, use that as the section name + model = fit_model() + other_section_name = "Here is the model diagram" + model_card = Card(model, model_diagram=other_section_name) + + # first check that default section only contains placeholder + result = model_card.select( + "Model description/Training Procedure/Model Plot" + ).format() + assert result == "The model plot is below." + + # now check that the actual model diagram is in the other section + result = model_card.select(other_section_name).format() + assert result.startswith("The model plot is below.\n\n