Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

log_image: Support matplotlib.figure.Figure as input. #658

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/dvclive/plots/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path, PurePath

from dvclive.utils import isinstance_without_import

from .base import Data


Expand All @@ -17,20 +19,25 @@ def output_path(self) -> Path:
def could_log(val: object) -> bool:
acceptable = {
("numpy", "ndarray"),
("matplotlib.figure", "Figure"),
("PIL.Image", "Image"),
}
for cls in type(val).mro():
if (cls.__module__, cls.__name__) in acceptable:
if any(isinstance_without_import(val, *cls) for cls in acceptable):
return True
if isinstance(val, (PurePath, str)):
return True
return False

def dump(self, val, **kwargs) -> None: # noqa: ARG002
if val.__class__.__module__ == "numpy":
if isinstance_without_import(val, "numpy", "ndarray"):
from PIL import Image as ImagePIL

pil_image = ImagePIL.fromarray(val)
else:
pil_image = val
pil_image.save(self.output_path)
ImagePIL.fromarray(val).save(self.output_path)
elif isinstance_without_import(val, "matplotlib.figure", "Figure"):
import matplotlib.pyplot as plt

plt.savefig(self.output_path)
plt.close(val)
elif isinstance_without_import(val, "PIL.Image", "Image"):
val.save(self.output_path)
7 changes: 7 additions & 0 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,10 @@ def clean_and_copy_into(src: StrPath, dst: StrPath) -> str:
shutil.copy2(src, dst_path)

return str(dst_path)


def isinstance_without_import(val, module, name):
for cls in type(val).mro():
if (cls.__module__, cls.__name__) == (module, name):
return True
return False
15 changes: 15 additions & 0 deletions tests/plots/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
Expand Down Expand Up @@ -100,3 +101,17 @@ def test_custom_class(tmp_dir):
live.log_image("image.png", extended_img)

assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


def test_matplotlib(tmp_dir):
live = Live()
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4])

assert plt.fignum_exists(fig.number)

live.log_image("image.png", fig)

assert not plt.fignum_exists(fig.number)

assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()