diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py index d7761d05..5cc65362 100644 --- a/src/dvclive/plots/image.py +++ b/src/dvclive/plots/image.py @@ -1,5 +1,7 @@ from pathlib import Path, PurePath +from dvclive.utils import isinstance_without_import + from .base import Data @@ -17,20 +19,25 @@ def output_path(self) -> Path: def could_log(val: object) -> bool: acceptable = { ("numpy", "ndarray"), + ("matplotlib.figure", "Figure"), ("PIL.Image", "Image"), } for cls in type(val).mro(): - if (cls.__module__, cls.__name__) in acceptable: + if any(isinstance_without_import(val, *cls) for cls in acceptable): return True if isinstance(val, (PurePath, str)): return True return False def dump(self, val, **kwargs) -> None: # noqa: ARG002 - if val.__class__.__module__ == "numpy": + if isinstance_without_import(val, "numpy", "ndarray"): from PIL import Image as ImagePIL - pil_image = ImagePIL.fromarray(val) - else: - pil_image = val - pil_image.save(self.output_path) + ImagePIL.fromarray(val).save(self.output_path) + elif isinstance_without_import(val, "matplotlib.figure", "Figure"): + import matplotlib.pyplot as plt + + plt.savefig(self.output_path) + plt.close(val) + elif isinstance_without_import(val, "PIL.Image", "Image"): + val.save(self.output_path) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 53b31a86..05b70502 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -148,3 +148,10 @@ def clean_and_copy_into(src: StrPath, dst: StrPath) -> str: shutil.copy2(src, dst_path) return str(dst_path) + + +def isinstance_without_import(val, module, name): + for cls in type(val).mro(): + if (cls.__module__, cls.__name__) == (module, name): + return True + return False diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index 8a277dc3..ed52efbc 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import numpy as np import pytest from PIL import Image @@ -100,3 +101,17 @@ def test_custom_class(tmp_dir): live.log_image("image.png", extended_img) assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists() + + +def test_matplotlib(tmp_dir): + live = Live() + fig, ax = plt.subplots() + ax.plot([1, 2, 3, 4]) + + assert plt.fignum_exists(fig.number) + + live.log_image("image.png", fig) + + assert not plt.fignum_exists(fig.number) + + assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()