ENH: Make skops.io.dump and load work with TextIOWrapper #234

Merged
merged 5 commits into from
Dec 2, 2022
Changes from 4 commits
3 changes: 3 additions & 0 deletions docs/changes.rst
@@ -11,6 +11,9 @@ skops Changelog

 v0.4
 ----
+- :func:`.io.dump` and :func:`.io.load` now work with :class:`io.TextIOWrapper`,
+  which means you can use them with the ``with open(...) as f: dump(obj, f)``
+  pattern, like you'd do with ``pickle``. :pr:`234` by `Benjamin Bossan`_.

 v0.3
 ----
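The changelog entry names :class:`io.TextIOWrapper`, while the test added in this PR opens the file with "wb" and "rb". As a side check that is not part of the PR, the following standard-library sketch records what `open` returns in each mode and whether a text-mode handle accepts bytes. The file name `example.bin` is made up for illustration.

```python
import io
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.bin")  # hypothetical file name

    # Binary mode: the handle is a buffered binary writer and accepts bytes.
    with open(path, "wb") as f_bin:
        binary_type = type(f_bin)
        f_bin.write(b"payload")

    # Text mode: the handle is a TextIOWrapper and only accepts str.
    with open(path, "w") as f_txt:
        text_type = type(f_txt)
        try:
            f_txt.write(b"payload")
            text_accepts_bytes = True
        except TypeError:
            text_accepts_bytes = False
```

With default buffering, binary mode yields an `io.BufferedWriter`, so the `with open(..., "wb")` pattern used in the test is the one that can receive the zip bytes.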
14 changes: 9 additions & 5 deletions skops/io/_persist.py
@@ -4,7 +4,7 @@
 import io
 import json
 from pathlib import Path
-from typing import Any, Sequence
+from typing import Any, BinaryIO, Sequence
 from zipfile import ZipFile

 import skops
@@ -39,7 +39,7 @@ def _save(obj: Any) -> io.BytesIO:
     return buffer


-def dump(obj: Any, file: str) -> None:
+def dump(obj: Any, file: str | Path | BinaryIO) -> None:
     """Save an object using the skops persistence format.

     Skops aims at providing a secure persistence feature that does not rely on
@@ -58,15 +58,19 @@ def dump(obj: Any, file: str) -> None:
     obj: object
         The object to be saved. Usually a scikit-learn compatible model.

-    file: str
+    file: str, path, or file-like object
         The file name. A zip archive will automatically be created. As a matter of
         convention, we recommend using the ".skops" file extension, e.g.
         ``dump(model, "my-model.skops")``.

     """
     buffer = _save(obj)
-    with open(file, "wb") as f:
-        f.write(buffer.getbuffer())
+
+    if isinstance(file, (str, Path)):
+        with open(file, "wb") as f:
+            f.write(buffer.getbuffer())
+    else:
+        file.write(buffer.getbuffer())


 def dumps(obj: Any) -> bytes:
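The dispatch added to `dump` can be tried in isolation. The sketch below uses only the standard library; `write_payload` is a hypothetical stand-in for the final write step, not a skops function, and the byte string replaces the zip buffer that `_save` would produce.

```python
import io
import tempfile
from pathlib import Path
from typing import BinaryIO, Union


def write_payload(payload: bytes, file: Union[str, Path, BinaryIO]) -> None:
    """Hypothetical helper mirroring the str/Path vs. file-like branch."""
    if isinstance(file, (str, Path)):
        # A path was given: open and close the file ourselves.
        with open(file, "wb") as f:
            f.write(payload)
    else:
        # A file-like object was given: write to it and leave it open,
        # since the caller owns the handle.
        file.write(payload)


payload = b"example bytes"

with tempfile.TemporaryDirectory() as tmp:
    # Case 1: pass a path.
    by_path = Path(tmp) / "by-path.bin"
    write_payload(payload, by_path)
    from_path = by_path.read_bytes()

    # Case 2: pass an open binary handle, as in the "with open(...)" pattern.
    by_handle = Path(tmp) / "by-handle.bin"
    with open(by_handle, "wb") as f:
        write_payload(payload, f)
    from_handle = by_handle.read_bytes()

# Case 3: pass an in-memory buffer.
buffer = io.BytesIO()
write_payload(payload, buffer)
from_buffer = buffer.getvalue()
```

All three targets end up with the same bytes. One consequence of this design is that anything that is not a `str` or `Path` is assumed to have a `write` method that accepts bytes-like data, so a text-mode handle would fail at write time rather than at the type check.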
15 changes: 15 additions & 0 deletions skops/io/tests/test_persist.py
@@ -968,6 +968,21 @@ def test_disk_and_memory_are_identical(tmp_path):
     assert joblib.hash(loaded_disk) == joblib.hash(loaded_memory)


+def test_dump_and_load_with_file_wrapper(tmp_path):
+    # The idea here is to make it possible to use dump and load with a file
+    # wrapper, i.e. using 'with open(...)'. This makes it easier to search and
+    # replace pickle dump and load by skops dump and load.
+    estimator = LogisticRegression().fit([[0, 1], [2, 3], [4, 5]], [0, 1, 1])
+    f_name = tmp_path / "estimator.skops"
+
+    with open(f_name, "wb") as f:
+        dump(estimator, f)
+    with open(f_name, "rb") as f:
+        loaded = load(f, trusted=True)
+
+    assert_params_equal(loaded.__dict__, estimator.__dict__)
+
+
 @pytest.mark.parametrize(
     "obj",
     [
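The diff shown here changes only `dump`, yet the test also passes an open handle to `load`. One plausible reason, stated as an assumption because the body of `load` does not appear in this diff, is that `load` reads the archive through `zipfile.ZipFile`, which accepts a path or a binary file-like object. A minimal standard-library sketch of that behaviour, with a made-up member name and content:

```python
import io
import zipfile

# Write a small archive into an in-memory binary buffer.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    # Member name and content are illustrative only.
    archive.writestr("schema.json", '{"example": true}')

# Rewind and read it back by handing the file-like object to ZipFile,
# the same way an open "rb" handle could be passed.
buffer.seek(0)
with zipfile.ZipFile(buffer, "r") as archive:
    names = archive.namelist()
    content = archive.read("schema.json").decode("utf-8")
```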