Add is_sync logic to Metric #339

Merged: 31 commits merged on Jul 1, 2021

Changes from 18 commits

Commits (31)
67ec1f7
add is_synced logic to Metric base class
tchaton Jul 1, 2021
109efa7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2021
db859ff
resolve flake8
tchaton Jul 1, 2021
1db1903
update
tchaton Jul 1, 2021
5a218b5
Merge branch 'fix_ddp' of https://github.com/PyTorchLightning/metrics…
tchaton Jul 1, 2021
c34f788
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2021
4290cd2
update
tchaton Jul 1, 2021
f9bbc2d
Merge branch 'fix_ddp' of https://github.com/PyTorchLightning/metrics…
tchaton Jul 1, 2021
9466c96
resolve a bug
tchaton Jul 1, 2021
1b6fbee
resolve tests
tchaton Jul 1, 2021
711c74f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2021
2ce7581
clean API
tchaton Jul 1, 2021
176415c
Merge branch 'fix_ddp' of https://github.com/PyTorchLightning/metrics…
tchaton Jul 1, 2021
74477c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2021
ab386e1
clean api
tchaton Jul 1, 2021
366f91c
Merge branch 'fix_ddp' of https://github.com/PyTorchLightning/metrics…
tchaton Jul 1, 2021
e5b53b2
update
tchaton Jul 1, 2021
3b9356c
resolve test
tchaton Jul 1, 2021
54c8e38
Apply suggestions from code review
Borda Jul 1, 2021
3b8a01e
cleanup
tchaton Jul 1, 2021
e034dd9
cleanup
tchaton Jul 1, 2021
8747226
Merge branch 'fix_ddp' of https://github.com/PyTorchLightning/metrics…
tchaton Jul 1, 2021
5fb1843
remove commented code
tchaton Jul 1, 2021
373bb7c
update on comments
tchaton Jul 1, 2021
f51b2e4
add docstring
tchaton Jul 1, 2021
8a652f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2021
7fa9283
update on comments
tchaton Jul 1, 2021
0688430
chlog
Borda Jul 1, 2021
bad75aa
Merge branch 'master' into fix_ddp
mergify[bot] Jul 1, 2021
77d3389
update on comments
tchaton Jul 1, 2021
5d33441
add comments
tchaton Jul 1, 2021
tests/bases/test_ddp.py (72 changes: 62 additions & 10 deletions)
@@ -14,7 +14,6 @@
import os
import sys
from copy import deepcopy
from unittest import mock

import pytest
import torch
@@ -24,6 +23,7 @@
from tests.helpers.testers import DummyMetric, setup_ddp
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import MisconfigurationException

seed_all(42)

@@ -138,30 +138,82 @@ def update(self, x):
def compute(self):
return self.x // self.c

def __repr__(self):
return f"DummyCatMetric(x={self.x}, c={self.c})"

metric = DummyCatMetric()
metric.persistent(True)

def verify_metric(metric, i, world_size):
state_dict = metric.state_dict()
exp_sum = i * (i + 1) / 2
assert state_dict["x"] == exp_sum * world_size
assert metric.x == exp_sum * world_size
assert metric.c == (i + 1) * world_size
assert state_dict["c"] == metric.c

steps = 5
for i in range(steps):

if metric.is_synced:

with pytest.raises(MisconfigurationException, match="The Metric shouldn't be synced when performing"):
metric(i)

metric.unsync()

metric(i)
state_dict = metric.state_dict()

exp_sum = i * (i + 1) / 2
assert state_dict["x"] == exp_sum * worldsize
assert metric.x == exp_sum
assert metric.c == (i + 1)
assert state_dict["c"] == metric.c * worldsize
verify_metric(metric, i, 1)

metric.sync()
assert metric.is_synced

with pytest.raises(MisconfigurationException, match="The Metric has already been synced."):
metric.sync()

verify_metric(metric, i, 2)

metric.unsync()
assert not metric.is_synced

with pytest.raises(MisconfigurationException, match="The Metric has already been un-synced."):
metric.unsync()

with metric.sync_context():
assert metric.is_synced
verify_metric(metric, i, 2)

with metric.sync_context(should_unsync=False):
assert metric.is_synced
verify_metric(metric, i, 2)

assert metric.is_synced

metric.unsync()

assert not metric.is_synced

metric.sync()

def reload_state_dict(state_dict, expected_x, expected_c):
metric = DummyCatMetric()
metric.load_state_dict(state_dict)
assert metric.x == expected_x
assert metric.c == expected_c

with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}):
reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0)
reload_state_dict(deepcopy(metric.state_dict()), 20, 10)

metric.unsync()
reload_state_dict(deepcopy(metric.state_dict()), 10, 5)

metric.sync()

torch.save(metric.state_dict(), os.path.join(tmpdir, 'weights.pt'))

reload_state_dict(deepcopy(state_dict), 20, 10)
metric.unsync()
with metric.sync_context():
torch.save(metric.state_dict(), os.path.join(tmpdir, 'weights.pt'))


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
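Illustration (not part of the diff): the updated test above exercises the manual synchronization API this PR adds to `Metric`. The sketch below shows the intended calling pattern with a hypothetical `SumMetric`; it assumes a torchmetrics build that contains this PR and that the code runs inside an initialized torch.distributed process group (as the test does via `setup_ddp`) — outside a distributed run, `sync()` is a no-op and `unsync()` would raise.

```python
# Sketch only: assumes a torchmetrics build that includes this PR (sync()/unsync()/is_synced)
# and that run() executes inside an initialized torch.distributed process group.
import torch
from torchmetrics import Metric


class SumMetric(Metric):
    """Hypothetical metric used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total = self.total + x

    def compute(self) -> torch.Tensor:
        return self.total


def run(rank: int, world_size: int) -> None:
    metric = SumMetric()
    metric(torch.tensor(1.0))  # forward()/update() require the metric to be un-synced

    metric.sync()              # reduce states across processes; the local state is cached
    # metric.is_synced is now True and metric.total holds the sum over all ranks;
    # calling metric(...) here would raise MisconfigurationException

    metric.unsync()            # restore the cached per-process state
    metric(torch.tensor(2.0))  # accumulation continues locally (is_synced is False again)
```

The `is_synced` flag and the cached local state used here are exactly what the `forward()`, `sync()` and `unsync()` changes in torchmetrics/metric.py below check and restore.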
torchmetrics/metric.py (102 changes: 65 additions & 37 deletions)
@@ -14,7 +14,6 @@
import functools
import inspect
import operator
import os
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
@@ -28,6 +27,7 @@
from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import MisconfigurationException
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version


@@ -68,7 +68,7 @@ class Metric(nn.Module, ABC):
will be used to perform the allgather.
"""

__jit_ignored_attributes__ = ["is_differentiable"]
__jit_ignored_attributes__ = ['is_differentiable', 'device', 'dtype']

def __init__(
self,
@@ -90,7 +90,7 @@ def __init__(
self.process_group = process_group
self.dist_sync_fn = dist_sync_fn
self._to_sync = True
self._restore_cache = True
self._should_unsync = True

self._update_signature = inspect.signature(self.update)
self.update: Callable = self._wrap_update(self.update) # type: ignore
@@ -104,6 +104,10 @@ def __init__(
self._persistent: Dict[str, bool] = {}
self._reductions: Dict[str, Union[str, Callable[[Union[List[Tensor], Tensor]], Tensor], None]] = {}

# state management
self.is_synced = False
self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

def add_state(
self,
name: str,
@@ -176,13 +180,19 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
"""
# add current step
if self.is_synced:
raise MisconfigurationException(
"The Metric shouldn't be synced when performing ``update``. "
"HINT: Did you forget to call ``unsync`` ?."
)

with torch.no_grad():
self.update(*args, **kwargs)

if self.compute_on_step:
self._to_sync = self.dist_sync_on_step
# skip restore cache operation from compute as cache is stored below.
self._restore_cache = False
self._should_unsync = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
@@ -195,8 +205,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
# restore context
for attr, val in cache.items():
setattr(self, attr, val)
self.is_synced = False

self._restore_cache = True
self._should_unsync = True
self._to_sync = True
self._computed = None

@@ -245,7 +256,7 @@ def sync(
process_group: Optional[Any] = None,
should_sync: bool = True,
distributed_available: Optional[Callable] = jit_distributed_available,
) -> Dict[str, Tensor]:
) -> None:
"""
Sync function for manually controlling when metrics states should be synced across processes

@@ -258,27 +269,47 @@
only when running in a distributed setting.
distributed_available: Function to determine if we are running inside a distributed setting

Returns:
cache: A dictionary containing the local metric states. The cache will be empty if sync didn't happen.
"""
if self.is_synced and should_sync:
raise MisconfigurationException("The Metric has already been synced.")

is_distributed = distributed_available() if callable(distributed_available) else None

if not should_sync or not is_distributed:
return {}
return

if dist_sync_fn is None:
dist_sync_fn = gather_all_tensors

# cache prior to syncing
cache = {attr: getattr(self, attr) for attr in self._defaults}
self._cache = {attr: getattr(self, attr) for attr in self._defaults}

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)
return cache
self.is_synced = True

def unsync(self, should_unsync: bool = True) -> None:
if not should_unsync:
return
if self.is_synced:
if self._cache is None:
raise MisconfigurationException("The internal cache should exist to unsync the Metric.")

# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in self._cache.items():
setattr(self, attr, val)
self.is_synced = False
self._cache = None
else:
raise MisconfigurationException("The Metric has already been un-synced.")

@contextmanager
def sync_context(
self,
dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
restore_cache: bool = True,
should_unsync: bool = True,
distributed_available: Optional[Callable] = jit_distributed_available,
) -> Generator:
"""
@@ -292,11 +323,11 @@ def sync_context(
default: None (which selects the entire world)
should_sync: Whether to apply to state synchronization. This will have an impact
only when running in a distributed setting.
restore_cache: Whether to restore the cache state so that the metrics can
should_unsync: Whether to restore the cache state so that the metrics can
continue to be accumulated.
distributed_available: Function to determine if we are running inside a distributed setting
"""
cache = self.sync(
self.sync(
dist_sync_fn=dist_sync_fn,
process_group=process_group,
should_sync=should_sync,
@@ -305,10 +336,13 @@

yield

if cache and restore_cache:
self.unsync(should_unsync=should_unsync)

if self._cache and should_unsync:
# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in cache.items():
for attr, val in self._cache.items():
setattr(self, attr, val)
self.is_synced = False

def _wrap_compute(self, compute: Callable) -> Callable:

@@ -326,7 +360,7 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
return self._computed

with self.sync_context(
dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache
dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, should_unsync=self._should_unsync
):
self._computed = compute(*args, **kwargs)

@@ -363,6 +397,9 @@ def reset(self) -> None:
setattr(self, attr, default.detach().clone().to(current_val.device))
else:
setattr(self, attr, [])
# reset internal states
self._cache = None
self.is_synced = False

def clone(self) -> "Metric":
""" Make a copy of the metric """
@@ -418,23 +455,18 @@ def state_dict(
) -> Optional[Dict[str, Any]]:
destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
# Register metric states to be part of the state_dict
with self.sync_context(dist_sync_fn=self.dist_sync_fn):
for key in self._defaults:
if not self._persistent[key]:
continue
current_val = getattr(self, key)
if not keep_vars:
if isinstance(current_val, Tensor):
current_val = current_val.detach()
elif isinstance(current_val, list):
current_val = [cur_v.detach() if isinstance(cur_v, Tensor) else cur_v for cur_v in current_val]
# the tensors will be synced across processes so deepcopy to drop the references
destination[prefix + key] = deepcopy(current_val) # type: ignore
for key in self._defaults:
if not self._persistent[key]:
continue
current_val = getattr(self, key)
if not keep_vars:
if isinstance(current_val, Tensor):
current_val = current_val.detach()
elif isinstance(current_val, list):
current_val = [cur_v.detach() if isinstance(cur_v, Tensor) else cur_v for cur_v in current_val]
destination[prefix + key] = deepcopy(current_val) # type: ignore
return destination

def _should_load_from_state_dict(self) -> bool:
return os.getenv("GLOBAL_RANK", "0") == "0"

def _load_from_state_dict(
self,
state_dict: dict,
@@ -447,14 +479,10 @@
) -> None:
""" Loads metric states from state_dict """

# only global rank 0 should be reloading the values present in the ``state_dict``
# as the state contains synced values across all progress_group
for key in self._defaults:
name = prefix + key
if name in state_dict:
value = state_dict.pop(name)
if self._should_load_from_state_dict():
setattr(self, key, value)
setattr(self, key, state_dict.pop(name))
super()._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
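Illustration (not part of the diff): the `sync_context()` manager documented above brackets a temporarily synced view of the state and, with `should_unsync=True` (the default), restores the per-process state on exit. A minimal sketch, continuing the hypothetical `SumMetric`/`metric` from the sketch after test_ddp.py, again assuming a distributed run:

```python
# Sketch only: `metric` is the hypothetical SumMetric from the sketch after test_ddp.py,
# running inside an initialized torch.distributed process group.
import torch

metric.persistent(True)  # mark states persistent so they appear in state_dict(), as the test does

with metric.sync_context():
    # inside the context the states have been reduced across processes
    synced_total = metric.total
    # mirror the updated test: persist the synced (global) state
    torch.save(metric.state_dict(), "weights.pt")

# on exit the per-process state is restored (should_unsync=True by default),
# so local accumulation can continue on each rank
metric(torch.tensor(3.0))
```

This mirrors the pattern the updated test uses to save a synced state_dict, since after this PR `state_dict()` no longer performs any synchronization itself.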
torchmetrics/utilities/exceptions.py (19 changes: 19 additions & 0 deletions)
@@ -0,0 +1,19 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class MisconfigurationException(Exception):
"""
Exception used to inform users of mis-use with PyTorch Lightning
"""