Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ skops Changelog

v0.7
----

- `compression` and `compresslevel` from :class:`~zipfile.ZipFile` are now
exposed to the user via :func:`.io.dumps` and :func:`.io.dump`. :pr:`345` by
`Adrin Jalali`_.

v0.6
----
Expand Down
16 changes: 16 additions & 0 deletions docs/persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,22 @@ you have custom functions (say, a custom function to be used with
most ``numpy`` and ``scipy`` functions should work. Therefore, you can save
objects having references to functions such as ``numpy.sqrt``.

Compression
~~~~~~~~~~~

If file size is an issue, you can compress the file by setting the
``compression`` and ``compresslevel`` arguments to :func:`skops.io.dump` and
:func:`skops.io.dumps`. For example, to compress the file using ``zlib`` with
level 9:

.. code:: python

from zipfile import ZIP_DEFLATED
dump(clf, "my-model.skops", compression=ZIP_DEFLATED, compresslevel=9)

Check the documentation of these two arguments under :class:`zipfile.ZipFile`
for more details.

Command Line Interface
######################

Expand Down
58 changes: 39 additions & 19 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
from pathlib import Path
from typing import Any, BinaryIO, Sequence
from zipfile import ZipFile
from zipfile import ZIP_STORED, ZipFile

import skops

Expand All @@ -26,10 +26,12 @@
NODE_TYPE_MAPPING.update(module.NODE_TYPE_MAPPING)


def _save(obj: Any) -> io.BytesIO:
def _save(obj: Any, compression: int, compresslevel: int | None) -> io.BytesIO:
buffer = io.BytesIO()

with ZipFile(buffer, "w") as zip_file:
with ZipFile(
buffer, "w", compression=compression, compresslevel=compresslevel
) as zip_file:
save_context = SaveContext(zip_file=zip_file)
state = get_state(obj, save_context)
save_context.clear_memo()
Expand All @@ -41,19 +43,19 @@ def _save(obj: Any) -> io.BytesIO:
return buffer


def dump(obj: Any, file: str | Path | BinaryIO) -> None:
def dump(
obj: Any,
file: str | Path | BinaryIO,
*,
compression: int = ZIP_STORED,
compresslevel: int | None = None,
) -> None:
"""Save an object using the skops persistence format.

Skops aims at providing a secure persistence feature that does not rely on
:mod:`pickle`, which is inherently insecure. For more information, please
visit the :ref:`persistence` documentation.

.. warning::

This feature is heavily under development, which means the API is
unstable and there might be security issues at the moment. Therefore,
use caution when loading files from sources you don't trust.

Parameters
----------
obj: object
Expand All @@ -64,8 +66,19 @@ def dump(obj: Any, file: str | Path | BinaryIO) -> None:
convention, we recommend to use the ".skops" file extension, e.g.
``save(model, "my-model.skops")``.

compression: int, default=zipfile.ZIP_STORED
The compression method to use. See :class:`zipfile.ZipFile` for more
information.

.. versionadded:: 0.7

compresslevel: int, default=None
The compression level to use. See :class:`zipfile.ZipFile` for more
information.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to make a recommendation (here or in the docs) about what settings to use? E.g. "if file size is no issue, use the defaults, as they are the fastest. Otherwise, use ".

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added something in the docs.


.. versionadded:: 0.7
"""
buffer = _save(obj)
buffer = _save(obj, compression=compression, compresslevel=compresslevel)

if isinstance(file, (str, Path)):
with open(file, "wb") as f:
Expand All @@ -74,22 +87,29 @@ def dump(obj: Any, file: str | Path | BinaryIO) -> None:
file.write(buffer.getbuffer())


def dumps(obj: Any) -> bytes:
def dumps(
obj: Any, *, compression: int = ZIP_STORED, compresslevel: int | None = None
) -> bytes:
"""Save an object using the skops persistence format as a bytes object.

.. warning::

This feature is heavily under development, which means the API is
unstable and there might be security issues at the moment. Therefore,
use caution when loading files from sources you don't trust.

Parameters
----------
obj: object
The object to be saved. Usually a scikit-learn compatible model.

compression: int, default=zipfile.ZIP_STORED
The compression method to use. See :class:`zipfile.ZipFile` for more
information.

.. versionadded:: 0.7

compresslevel: int, default=None
The compression level to use. See :class:`zipfile.ZipFile` for more
information.

.. versionadded:: 0.7
"""
buffer = _save(obj)
buffer = _save(obj, compression=compression, compresslevel=compresslevel)
return buffer.getbuffer().tobytes()


Expand Down
13 changes: 12 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import Counter
from functools import partial, wraps
from pathlib import Path
from zipfile import ZipFile
from zipfile import ZIP_DEFLATED, ZipFile

import joblib
import numpy as np
Expand All @@ -20,6 +20,7 @@
from sklearn.decomposition import SparseCoder
from sklearn.exceptions import SkipTestWarning
from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import (
GridSearchCV,
Expand Down Expand Up @@ -1002,3 +1003,13 @@ def test_persist_function(func):
# check that loaded estimator is identical
assert_params_equal(estimator.__dict__, loaded.__dict__)
assert_method_outputs_equal(estimator, loaded, X)


def test_compression_level():
# Test that setting the compression to zlib and specifying a
# compressionlevel reduces the dumped size.
model = TfidfVectorizer().fit([np.__doc__])
dumped_raw = dumps(model)
dumped_compressed = dumps(model, compression=ZIP_DEFLATED, compresslevel=9)
# This reduces the size substantially
assert len(dumped_raw) > 5 * len(dumped_compressed)