
[feat] Add sync_context and sync to nn.Metric #302

Merged
merged 49 commits into master Jun 21, 2021

Conversation


@tchaton tchaton commented Jun 17, 2021

What does this PR do?

This PR adds a more generic _apply_sync function to the nn.Metric base class.

Fixes: #67

Reasoning:
Users might want to perform a reduction without also performing compute.

Current use case:
In Lightning, we are enabling mid-epoch restarts.
To do this, we need to save the synchronized states across processes on rank 0.
The compute call is therefore pure overhead and should be skipped.

Did you have fun?

Make sure you had fun coding 🙃
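For context, the sync-then-restore pattern this PR introduces can be sketched without torch as follows. This is a toy stand-in, not the torchmetrics implementation: MiniMetric, the fake two-process gather, and the plain-sum reduction are all illustrative assumptions; only the method names mirror the PR.

```python
from contextlib import contextmanager
from typing import Callable, Dict

class MiniMetric:
    """Toy stand-in for torchmetrics.Metric illustrating sync/sync_context."""

    def __init__(self) -> None:
        self.x = 0  # metric state

    def update(self, value: int) -> None:
        self.x += value

    def sync(self, dist_sync_fn: Callable = sum) -> Dict[str, int]:
        # Cache the un-synced state, then replace it with the reduced value.
        cache = {"x": self.x}
        # In real DDP this would all_gather "x" from every process;
        # here we fake two processes holding the same local state.
        gathered = [self.x, self.x]
        self.x = dist_sync_fn(gathered)
        return cache

    def unsync(self, cache: Dict[str, int]) -> None:
        # Restore the cached local state so accumulation can continue.
        for attr, val in cache.items():
            setattr(self, attr, val)

    @contextmanager
    def sync_context(self, should_sync: bool = True, restore_cache: bool = True):
        cache = self.sync() if should_sync else None
        yield
        if cache is not None and restore_cache:
            self.unsync(cache)

m = MiniMetric()
m.update(3)
with m.sync_context():
    synced = m.x   # reduced across the two fake processes
local = m.x        # restored local state after the context exits
```

Inside the context the state is the world-reduced value (6 here), and after the context it is rolled back to the local value (3), which is what lets Lightning snapshot synced state on rank 0 and keep accumulating afterwards.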

@tchaton tchaton added this to the v0.5 milestone Jun 17, 2021
@tchaton tchaton self-assigned this Jun 17, 2021

codecov bot commented Jun 17, 2021

Codecov Report

Merging #302 (9872ddd) into master (f54ccca) will increase coverage by 0.08%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #302      +/-   ##
==========================================
+ Coverage   96.35%   96.44%   +0.08%     
==========================================
  Files          97       97              
  Lines        3241     3259      +18     
==========================================
+ Hits         3123     3143      +20     
+ Misses        118      116       -2     
Flag Coverage Δ
Linux 76.79% <50.00%> (-0.03%) ⬇️
Windows 76.79% <50.00%> (-0.03%) ⬇️
cpu 96.37% <100.00%> (+0.02%) ⬆️
gpu 96.40% <100.00%> (?)
macOS 96.37% <100.00%> (+0.02%) ⬆️
pytest 96.44% <100.00%> (+0.08%) ⬆️
python3.6 95.41% <100.00%> (+0.02%) ⬆️
python3.8 96.28% <100.00%> (-0.08%) ⬇️
python3.9 ?
torch1.3.1 95.41% <100.00%> (+0.02%) ⬆️
torch1.4.0 ?
torch1.9.0 96.28% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
torchmetrics/metric.py 95.51% <100.00%> (+0.27%) ⬆️
torchmetrics/functional/regression/spearman.py 97.77% <0.00%> (+4.44%) ⬆️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f54ccca...9872ddd. Read the comment docs.


pep8speaks commented Jun 17, 2021

Hello @tchaton! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-06-21 09:41:26 UTC

@SkafteNicki (Member)

Hi @tchaton,
This seems related to #67, so maybe it should be a public method?


Borda commented Jun 18, 2021

seems constantly failing on PT 1.6

E       Exception: 
E       
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/usr/share/miniconda3/envs/test/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
E           fn(i, *args)
E         File "/home/runner/work/metrics/metrics/tests/bases/test_ddp.py", line 146, in _test_state_dict_is_synced
E           metric(i)
E         File "/usr/share/miniconda3/envs/test/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
E           result = self.forward(*input, **kwargs)
E         File "/home/runner/work/metrics/metrics/torchmetrics/metric.py", line 189, in forward
E           self._forward_cache = self.compute()
E         File "/home/runner/work/metrics/metrics/torchmetrics/metric.py", line 326, in wrapped_func
E           self._computed = compute(*args, **kwargs)
E         File "/home/runner/work/metrics/metrics/tests/bases/test_ddp.py", line 139, in compute
E           return self.x / self.c
E       RuntimeError: Integer division of tensors using div or / is no longer supported, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.

@Borda Borda enabled auto-merge (squash) June 18, 2021 13:33

tchaton commented Jun 18, 2021

> seems constantly failing on PT 1.6

Thanks. Resolved.
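For reference, the failure above is torch 1.6's deprecation of `/` on integer tensors. The distinction, shown with plain Python numbers since it is the same true-vs-floor division question (the commented torch lines are an assumed shape of the fix, not necessarily the exact change made in this PR):

```python
x, c = 9, 4
true_div = x / c    # true division, what a mean-style metric wants
floor_div = x // c  # floor division, what the error message suggests for ints
assert true_div == 2.25
assert floor_div == 2
# For integer *tensors*, the fix is to make the division explicit, e.g.:
#   return self.x.float() / self.c            # true division
#   return torch.true_divide(self.x, self.c)  # equivalent
```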


Borda commented Jun 18, 2021

@tchaton any reason/justification why the tests now take an extra 5+ minutes?


tchaton commented Jun 18, 2021

> @tchaton any reason/justification why the tests take more than an extra 5min?

Great question. I am investigating. @SkafteNicki @justusschock Any idea?

Best,
T.C

self,
dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
Contributor

I'm confused to see should_sync=True|False.

If you set False, this method does nothing, so it's the same as not calling sync in the first place!
And if you set True but dist is not available, it also does nothing, so it does not do what the user wants.

Contributor Author

should_sync means sync if possible :) Modified the docstring to reflect this.

Contributor

Now there are two overlapping arguments: dist_sync_fn and should_sync.

You can do this: should_sync=False and dist_sync_fn=mean.

What will happen now? Will it sync or not?
@PyTorchLightning/core-metrics be aware of these cases

Contributor

Yeah, I agree. I wonder if the main use of should_sync is inside sync_context, and maybe we should just decide there whether syncing is needed. Wrapping a context manager in an if is a bit awkward and might justify a flag, but for a plain function it is easy for callers to simply not call it.

Contributor

@maximsch2 your argument is to keep the flag for the context manager but remove it from this function, correct?

I think that would be fine.
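To make the debated semantics concrete, here is a toy model of the decision logic. The parameter names mirror the signature quoted above, but the two-process gather is faked and this is one plausible reading, not the merged implementation: should_sync=False short-circuits, so a supplied dist_sync_fn is simply ignored.

```python
def sync(state, dist_sync_fn=None, should_sync=True,
         distributed_available=lambda: True):
    """Toy model of the should_sync / dist_sync_fn interaction."""
    if not should_sync or not distributed_available():
        # No-op: even an explicitly supplied dist_sync_fn is ignored.
        return state
    fn = dist_sync_fn or sum
    gathered = [state, state]  # stand-in for all_gather across 2 processes
    return fn(gathered)

assert sync(3) == 6                                   # default: sum over "ranks"
assert sync(3, dist_sync_fn=max, should_sync=False) == 3  # should_sync wins
```

Under this reading the answer to "will it sync or not?" is: it will not, regardless of dist_sync_fn, which is exactly the overlap being flagged.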

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
@tchaton tchaton changed the title [feat] Add _apply_sync to nn.Metric [feat] Add sync_context and sync to nn.Metric Jun 21, 2021
@Borda Borda merged commit fc3333b into master Jun 21, 2021
@Borda Borda deleted the apply_sync_fn branch June 21, 2021 10:27
@maximsch2 (Contributor)

Late to the party here, but we can also imagine a future where models are huge and sharded (with FSDP) and metric states are similarly sharded. We are getting away with synchronizing everything on rank 0 for now, but long term we might need metrics that won't be able to do that.

@ananthsub ananthsub left a comment:

@tchaton I had a different idea when proposing sync: the Metric subclass should provide the implementation, not the base. The metric states are updated in-place instead of returned to the caller. So one could chain together calls to update and sync before finally calling compute to get the state.
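A rough sketch of the chained in-place API described here. ChainableMetric and the in-place world-size multiply are illustrative stand-ins (the multiply fakes an all-reduce), not torchmetrics code; only the update/sync/compute names come from the comment.

```python
class ChainableMetric:
    """Sketch of the alternative API: update and sync mutate state
    in place and return self, so calls chain, and compute reads the
    (possibly synced) state at the end."""

    def __init__(self) -> None:
        self.total = 0

    def update(self, value: int) -> "ChainableMetric":
        self.total += value
        return self

    def sync(self, world_size: int = 2) -> "ChainableMetric":
        # Stand-in for an in-place all-reduce across processes.
        self.total *= world_size
        return self

    def compute(self) -> int:
        return self.total

assert ChainableMetric().update(2).update(3).sync().compute() == 10
```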

dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
distributed_available: Optional[Callable] = distributed_available,
Contributor

n00b question: why does distributed_available need to be an argument?

Contributor

Because torchmetrics does not know about distributed platforms other than CUDA, for example TPUs or IPUs.

Comment on lines +303 to +307
if cache and restore_cache:
# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in cache.items():
setattr(self, attr, val)

Contributor

What's the use case for this? If we sync, we should assume all metrics are operating off the synced state and not accumulate local changes, right?

Contributor

This was here already, just moved.

Added in dd1e744
cc: @SkafteNicki
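A tiny arithmetic illustration of why the restore matters (the numbers are made up, and the multiply stands in for an all-gather-and-sum across two ranks): without restoring the cache, later updates would accumulate on top of the already world-reduced value and double-count the other ranks' contributions.

```python
world_size = 2
local = 3
synced = local * world_size  # stand-in for all_gather + sum across ranks
# Continue the epoch with one more local update of 4:
wrong = synced + 4   # accumulated on top of the synced (reduced) state
right = local + 4    # accumulated on top of the restored local state
assert wrong == 10 and right == 7  # the un-restored path double-counts
```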

)

for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)
if isinstance(output_dict[attr][0], Tensor):
if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor):
Contributor

This check is not safe; we're seeing errors as a result:

if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor):

IndexError: list index out of range
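A guarded version of that check would avoid indexing an empty sequence. This is a sketch only: int stands in for torch.Tensor so it runs without torch, and real code would also need to exclude str, which is itself a Sequence.

```python
from collections.abc import Sequence
from typing import Any

def is_nonempty_tensor_seq(val: Any, tensor_type: type = int) -> bool:
    """Like the failing check, but the val[0] access is guarded by a
    length check, so empty lists return False instead of raising."""
    return (isinstance(val, Sequence)
            and len(val) > 0
            and isinstance(val[0], tensor_type))

assert not is_nonempty_tensor_seq([])       # no IndexError on an empty list
assert is_nonempty_tensor_seq([1, 2])
assert not is_nonempty_tensor_seq("abc")    # elements are str, not tensor_type
```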

Contributor

Removed in #311

Development

Successfully merging this pull request may close these issues.

Offer a dedicated sync() interface on the base Metric class
9 participants