Skip to content

Commit

Permalink
Bugfix: Removes buffers from state dict (#1728)
Browse files Browse the repository at this point in the history
* fixes
* changelog
  • Loading branch information
SkafteNicki authored Apr 26, 2023
1 parent a970ba6 commit c3628b0
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `max_det_threshold` in MAP detection ([#1712](https://github.com/Lightning-AI/torchmetrics/pull/1712))


- Fixed states being saved in metrics that use `register_buffer` ([#1728](https://github.com/Lightning-AI/torchmetrics/pull/1728))


## [0.11.4] - 2023-03-10

### Fixed
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.register_buffer("thresholds", thresholds)
self.register_buffer("thresholds", thresholds, persistent=False)
self.add_state(
"confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum"
)
Expand Down Expand Up @@ -311,7 +311,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.register_buffer("thresholds", thresholds)
self.register_buffer("thresholds", thresholds, persistent=False)
self.add_state(
"confmat",
default=torch.zeros(len(thresholds), num_classes, 2, 2, dtype=torch.long),
Expand Down Expand Up @@ -490,7 +490,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.register_buffer("thresholds", thresholds)
self.register_buffer("thresholds", thresholds, persistent=False)
self.add_state(
"confmat",
default=torch.zeros(len(thresholds), num_labels, 2, 2, dtype=torch.long),
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ class ScalingLayer(nn.Module):

def __init__(self) -> None:
super().__init__()
self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None], persistent=False)
self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None], persistent=False)

def forward(self, inp: Tensor) -> Tensor:
"""Process input."""
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,12 +1024,12 @@ def __init__(
self.op = operator

if isinstance(metric_a, Tensor):
self.register_buffer("metric_a", metric_a)
self.register_buffer("metric_a", metric_a, persistent=False)
else:
self.metric_a = metric_a

if isinstance(metric_b, Tensor):
self.register_buffer("metric_b", metric_b)
self.register_buffer("metric_b", metric_b, persistent=False)
else:
self.metric_b = metric_b

Expand Down
15 changes: 15 additions & 0 deletions tests/unittests/classification/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,18 @@ def test_valid_input_thresholds(metric, thresholds):
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0


@pytest.mark.parametrize(
    "metric",
    [
        BinaryPrecisionRecallCurve,
        partial(MulticlassPrecisionRecallCurve, num_classes=NUM_CLASSES),
        partial(MultilabelPrecisionRecallCurve, num_labels=NUM_CLASSES),
    ],
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_empty_state_dict(metric, thresholds):
    """Test that metrics have an empty state dict.

    Metric states and the ``thresholds`` buffer should be registered as
    non-persistent, so ``state_dict()`` must contain nothing regardless of
    which ``thresholds`` form is used to construct the metric.
    """
    m = metric(thresholds=thresholds)
    # Regression check for #1728: buffers registered with persistent=False
    # must not leak into the state dict.
    assert m.state_dict() == {}, "Metric state dict should be empty."

0 comments on commit c3628b0

Please sign in to comment.