diff --git a/docs/changes.rst b/docs/changes.rst index fb01aa3d..b92bcc5f 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -12,6 +12,8 @@ skops Changelog v0.7 ---- +- Add ability to copy plots on :meth:`.Card.save` so that they can be + referenced in the model card. :pr:`330` by :user:`Thomas Lazarus `. - `compression` and `compresslevel` from :class:`~zipfile.ZipFile` are now exposed to the user via :func:`.io.dumps` and :func:`.io.dump`. :pr:`345` by `Adrin Jalali`_. diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index 754c03f2..32ba9b8c 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -2,6 +2,7 @@ import json import re +import shutil import sys import textwrap import zipfile @@ -1318,7 +1319,10 @@ def _generate_metadata(self, metadata: ModelCardData) -> Iterator[str]: yield aRepr.repr(f"metadata.{key}={val},").strip('"').strip("'") def _generate_content( - self, data: dict[str, Section], depth: int = 1 + self, + data: dict[str, Section], + depth: int = 1, + destination_path: Path | None = None, ) -> Iterator[str]: """Yield title and (formatted) contents. @@ -1336,8 +1340,15 @@ def _generate_content( yield section.format() + if destination_path is not None and isinstance(section, PlotSection): + shutil.copy(section.path, destination_path) + if section.subsections: - yield from self._generate_content(section.subsections, depth=depth + 1) + yield from self._generate_content( + section.subsections, + depth=depth + 1, + destination_path=destination_path, + ) def _iterate_content( self, data: dict[str, Section], parent_section: str = "" @@ -1405,19 +1416,21 @@ def __repr__(self) -> str: complete_repr += ")" return complete_repr - def _generate_card(self) -> Iterator[str]: + def _generate_card(self, destination_path: Path | None = None) -> Iterator[str]: """Yield sections of the model card, including the metadata.""" if self.metadata.to_dict(): yield f"---\n{self.metadata.to_yaml()}\n---" - for line in self._generate_content(self._data): + for line in self._generate_content( + self._data, destination_path=destination_path + ): if line: yield "\n" + line # add an empty line add the end yield "" - def save(self, path: str | Path) -> None: + def save(self, path: str | Path, copy_files: bool = False) -> None: """Save the model card. This method renders the model card in markdown format and then saves it @@ -1425,16 +1438,24 @@ def save(self, path: str | Path) -> None: Parameters ---------- - path: str, or Path + path: Path Filepath to save your card. + plot_path: str + Filepath to save the plots. Use this when saving the model card before creating the + repository. Without this path the README will have an absolute path to the plot that + won't exist in the repository. + Notes ----- The keys in model card metadata can be seen `here `__. """ with open(path, "w", encoding="utf-8") as f: - f.write("\n".join(self._generate_card())) + if not isinstance(path, Path): + path = Path(path) + destination_path = path.parent if copy_files else None + f.write("\n".join(self._generate_card(destination_path=destination_path))) def render(self) -> str: """Render the final model card as a string. diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 32f15aab..9e625852 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -1916,3 +1916,25 @@ def test_toc_with_invisible_section(self, card): ] assert toc == "\n".join(exptected_toc) + + +class TestCardSaveWithPlots: + def test_copy_plots(self, destination_path, model_card): + import matplotlib.pyplot as plt + + with tempfile.TemporaryDirectory(prefix="skops-test-plots") as plot_path: + plt.plot([4, 5, 6, 7]) + fig_1_path = Path(plot_path) / "fig1.png" + plt.savefig(fig_1_path) + model_card = model_card.add_plot(fig1=fig_1_path) + + plt.plot([7, 6, 5, 4]) + fig_2_path = "fig2.png" + plt.savefig(fig_2_path) + model_card = model_card.add_plot(fig2=fig_2_path) + + model_card.save(Path(destination_path) / "README.md", copy_files=True) + + assert (Path(destination_path) / "README.md").exists() + assert (Path(destination_path) / "fig1.png").exists() + assert (Path(destination_path) / "fig2.png").exists()