diff --git a/skops/card/_model_card.py b/skops/card/_model_card.py index 74fa377a..33bef3be 100644 --- a/skops/card/_model_card.py +++ b/skops/card/_model_card.py @@ -146,8 +146,9 @@ def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData: with open(config_path) as f: config = json.load(f) - - card_data = ModelCardData() + card_data = ModelCardData( + model_format=config.get("sklearn", {}).get("model_format", {}) + ) card_data.library_name = "sklearn" card_data.tags = ["sklearn", "skops"] task = config.get("sklearn", {}).get("task", None) diff --git a/skops/card/tests/test_card.py b/skops/card/tests/test_card.py index 92e1c68c..df00a8c6 100644 --- a/skops/card/tests/test_card.py +++ b/skops/card/tests/test_card.py @@ -898,6 +898,18 @@ def test_metadata_from_config_tabular_data( for tag in ["sklearn", "skops", "tabular-classification"]: assert tag in metadata["tags"] + def test_metadata_model_format_pkl( + self, pkl_model_card_metadata_from_config, destination_path + ): + metadata = metadata_load(local_path=Path(destination_path) / "README.md") + assert metadata["model_format"] == "pickle" + + def test_metadata_model_format_skops( + self, skops_model_card_metadata_from_config, destination_path + ): + metadata = metadata_load(local_path=Path(destination_path) / "README.md") + assert metadata["model_format"] == "skops" + @pytest.mark.xfail(reason="dynamic adjustment when model changes not implemented yet") class TestModelDynamicUpdate: