Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ skops Changelog

v0.7
----
- Add ability to copy plots on :meth:`.Card.save` so that they can be
referenced in the model card. :pr:`330` by :user:`Thomas Lazarus <lazarust>`.
- `compression` and `compresslevel` from :class:`~zipfile.ZipFile` are now
exposed to the user via :func:`.io.dumps` and :func:`.io.dump`. :pr:`345` by
`Adrin Jalali`_.
Expand Down
35 changes: 28 additions & 7 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import re
import shutil
import sys
import textwrap
import zipfile
Expand Down Expand Up @@ -1318,7 +1319,10 @@ def _generate_metadata(self, metadata: ModelCardData) -> Iterator[str]:
yield aRepr.repr(f"metadata.{key}={val},").strip('"').strip("'")

def _generate_content(
self, data: dict[str, Section], depth: int = 1
self,
data: dict[str, Section],
depth: int = 1,
destination_path: Path | None = None,
) -> Iterator[str]:
"""Yield title and (formatted) contents.

Expand All @@ -1336,8 +1340,15 @@ def _generate_content(

yield section.format()

if destination_path is not None and isinstance(section, PlotSection):
shutil.copy(section.path, destination_path)

if section.subsections:
yield from self._generate_content(section.subsections, depth=depth + 1)
yield from self._generate_content(
section.subsections,
depth=depth + 1,
destination_path=destination_path,
)

def _iterate_content(
self, data: dict[str, Section], parent_section: str = ""
Expand Down Expand Up @@ -1405,36 +1416,46 @@ def __repr__(self) -> str:
complete_repr += ")"
return complete_repr

def _generate_card(self) -> Iterator[str]:
def _generate_card(self, destination_path: Path | None = None) -> Iterator[str]:
"""Yield sections of the model card, including the metadata."""
if self.metadata.to_dict():
yield f"---\n{self.metadata.to_yaml()}\n---"

for line in self._generate_content(self._data):
for line in self._generate_content(
self._data, destination_path=destination_path
):
if line:
yield "\n" + line

# add an empty line add the end
yield ""

def save(self, path: str | Path) -> None:
def save(self, path: str | Path, copy_files: bool = False) -> None:
"""Save the model card.

This method renders the model card in markdown format and then saves it
as the specified file.

Parameters
----------
path: str, or Path
path: Path
Filepath to save your card.

plot_path: str
Filepath to save the plots. Use this when saving the model card before creating the
repository. Without this path the README will have an absolute path to the plot that
won't exist in the repository.

Notes
-----
The keys in model card metadata can be seen `here
<https://huggingface.co/docs/hub/models-cards#model-card-metadata>`__.
"""
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(self._generate_card()))
if not isinstance(path, Path):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bit of redundancy here: You cast path to a Path object here but _generate_card and _generate_content both accept str and the latter does Path(destination_path). So either you just pass path without conversion, or those methods should not accept str and don't perform a 2nd round of conversion.

path = Path(path)
destination_path = path.parent if copy_files else None
f.write("\n".join(self._generate_card(destination_path=destination_path)))

def render(self) -> str:
"""Render the final model card as a string.
Expand Down
22 changes: 22 additions & 0 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,3 +1916,25 @@ def test_toc_with_invisible_section(self, card):
]

assert toc == "\n".join(exptected_toc)


class TestCardSaveWithPlots:
def test_copy_plots(self, destination_path, model_card):
import matplotlib.pyplot as plt

with tempfile.TemporaryDirectory(prefix="skops-test-plots") as plot_path:
plt.plot([4, 5, 6, 7])
fig_1_path = Path(plot_path) / "fig1.png"
plt.savefig(fig_1_path)
model_card = model_card.add_plot(fig1=fig_1_path)

plt.plot([7, 6, 5, 4])
fig_2_path = "fig2.png"
plt.savefig(fig_2_path)
model_card = model_card.add_plot(fig2=fig_2_path)

model_card.save(Path(destination_path) / "README.md", copy_files=True)

assert (Path(destination_path) / "README.md").exists()
assert (Path(destination_path) / "fig1.png").exists()
assert (Path(destination_path) / "fig2.png").exists()