diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index 1e54fc0e..754c03f2 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -7,6 +7,8 @@ import zipfile from collections.abc import Mapping from dataclasses import dataclass, field +from functools import cached_property +from hashlib import sha256 from pathlib import Path from reprlib import Repr from typing import Any, Iterator, Literal, Sequence, Union @@ -503,6 +505,7 @@ def __init__( self._data: dict[str, Section] = {} self._metrics: dict[str, str | float | int] = {} + self._model_hash = "" self._populate_template(model_diagram=model_diagram) @@ -564,9 +567,24 @@ def get_model(self) -> Any: The model instance. """ + if isinstance(self.model, (str, Path)) and hasattr(self, "_model"): + hash_obj = sha256() + buf_size = 2**20 # load in chunks to save memory + with open(self.model, "rb") as f: + for chunk in iter(lambda: f.read(buf_size), b""): + hash_obj.update(chunk) + model_hash = hash_obj.hexdigest() + + # if hash changed, invalidate cache by deleting attribute + if model_hash != self._model_hash: + del self._model + self._model_hash = model_hash + + return self._model + + @cached_property + def _model(self): model = _load_model(self.model, self.trusted) - # Ideally, we would only call the method below if we *know* that the - # model has changed, but at the moment we have no way of knowing that return model def add(self, **kwargs: str) -> Self: diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 7a3dfb27..32f15aab 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -4,6 +4,7 @@ import tempfile import textwrap from pathlib import Path +from unittest import mock import numpy as np import pytest @@ -25,7 +26,7 @@ TableSection, _load_model, ) -from skops.io import dump +from skops.io import dump, load from skops.utils.importutils import import_or_raise @@ -145,6 +146,34 @@ def test_save_model_card(destination_path, model_card): assert (Path(destination_path) / "README.md").exists() +def test_model_caching( + skops_model_card_metadata_from_config, iris_skops_file, destination_path +): + """Tests that the model card caches the model to avoid loading it multiple times""" + + new_model = LogisticRegression(random_state=4321) + # mock _load_model, it still loads the model but we can track call count + mock_load_model = mock.Mock(side_effect=load) + card = Card(iris_skops_file, metadata=metadata_from_config(destination_path)) + with mock.patch("skops.card._model_card._load_model", mock_load_model): + model1 = card.get_model() + model2 = card.get_model() + assert model1 is model2 + # model is cached, hence _load_model is not called + mock_load_model.assert_not_called() + + # override model with new model + dump(new_model, card.model) + + model3 = card.get_model() + assert mock_load_model.call_count == 1 + assert model3.random_state == 4321 + model4 = card.get_model() + + assert model3 is model4 + assert mock_load_model.call_count == 1 # cached call + + CUSTOM_TEMPLATES = [None, {}, {"A Title", "Another Title", "A Title/A Section"}] # type: ignore