From d42fd04a68e81cfa1e23bb05ccb1621b0ed6e0e4 Mon Sep 17 00:00:00 2001 From: Piotr Picheta Date: Thu, 18 Apr 2024 15:56:05 +0200 Subject: [PATCH] Fixes #812 --- src/dvclive/live.py | 2 +- tests/plots/test_sklearn.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 42e79e96..1bdeaae5 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -664,7 +664,7 @@ def log_sklearn_plot( raise InvalidPlotTypeError(name) sklearn_kwargs = { - k: v for k, v in kwargs.items() if k not in plot_config or k != "normalized" + k: v for k, v in kwargs.items() if k not in plot_config or k == "normalized" } plot.step = self.step plot.dump(val, **sklearn_kwargs) diff --git a/tests/plots/test_sklearn.py b/tests/plots/test_sklearn.py index 848765e7..1b7cb690 100644 --- a/tests/plots/test_sklearn.py +++ b/tests/plots/test_sklearn.py @@ -162,7 +162,7 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score): live = Live() out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder - y_true, y_pred, _ = y_true_y_pred_y_score + y_true, y_pred, y_score = y_true_y_pred_y_score live.log_sklearn_plot( "confusion_matrix", @@ -174,8 +174,13 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score): live.log_sklearn_plot( "confusion_matrix", y_true, y_pred, name="val/cm", title="Val Confusion Matrix" ) + live.log_sklearn_plot( + "precision_recall", y_true, y_score, name="val/prc", title="Val Precision Recall" + ) assert (out / "train" / "cm.json").exists() assert (out / "val" / "cm.json").exists() + assert (out / "val" / "prc.json").exists() assert live._plots["train/cm"].plot_config["title"] == "Train Confusion Matrix" assert live._plots["val/cm"].plot_config["title"] == "Val Confusion Matrix" + assert live._plots["val/prc"].plot_config["title"] == "Val Precision Recall"