sync random seed in DistributedSamplers (open-mmlab#1257)
ly015 committed Apr 2, 2022
1 parent 7f3d276 commit 36eb886
Showing 3 changed files with 49 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmpose/core/utils/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .dist_utils import allreduce_grads
+from .dist_utils import allreduce_grads, sync_random_seed
 from .regularizations import WeightNormClipHook
 
-__all__ = ['allreduce_grads', 'WeightNormClipHook']
+__all__ = ['allreduce_grads', 'WeightNormClipHook', 'sync_random_seed']
39 changes: 39 additions & 0 deletions mmpose/core/utils/dist_utils.py
@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import OrderedDict
 
+import numpy as np
+import torch
 import torch.distributed as dist
+from mmcv.runner import get_dist_info
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
@@ -49,3 +52,39 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
     else:
         for tensor in grads:
             dist.all_reduce(tensor.div_(world_size))
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed.
+
+    All workers must call this function, otherwise it will deadlock.
+    This method is generally used in `DistributedSampler`, because the
+    seed should be identical across all processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
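
To make the broadcast behaviour concrete, here is a minimal demo sketch (not part of the commit). It assumes a single machine, the gloo backend, and CPU tensors via device='cpu'; the worker function name and port are made up for illustration. Each rank draws its own seed locally, but after sync_random_seed returns, every rank holds the value drawn by rank 0.

# Hypothetical usage sketch; not part of the committed code.
import os

import torch.distributed as dist
import torch.multiprocessing as mp

from mmpose.core.utils import sync_random_seed  # exported by this commit


def _worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'  # arbitrary free port, chosen for this demo
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # All ranks must call this together, otherwise the broadcast deadlocks.
    seed = sync_random_seed(device='cpu')
    print(f'rank {rank}: seed={seed}')  # identical value printed on every rank

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2, ), nprocs=2)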
9 changes: 8 additions & 1 deletion mmpose/datasets/samplers/distributed_sampler.py
@@ -2,6 +2,8 @@
 import torch
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
+from mmpose.core import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):
     """DistributedSampler inheriting from
@@ -20,7 +22,12 @@ def __init__(self,
         super().__init__(
             dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
         # for the compatibility from PyTorch 1.3+
-        self.seed = seed if seed is not None else 0
+        # In distributed sampling, different ranks should sample non-overlapped
+        # data in the dataset. Therefore, this function is used to make sure
+        # that each rank shuffles the data indices in the same order based
+        # on the same seed. Then different ranks could use different indices
+        # to select non-overlapped data from the same data list.
+        self.seed = sync_random_seed(seed) if seed is not None else 0
 
     def __iter__(self):
         """Deterministically shuffle based on epoch."""
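
The diff collapses the body of __iter__, so as a reminder of what the synchronized seed feeds into, here is a sketch of the standard PyTorch DistributedSampler shuffling pattern; deterministic_shuffle is a hypothetical helper written for illustration, not the committed mmpose code. Every rank builds the same permutation from seed + epoch and then keeps a disjoint strided slice of it.

# Illustration of the usual pattern, under the assumptions stated above.
import torch


def deterministic_shuffle(num_samples, seed, epoch, rank, world_size):
    # Identical seed and epoch on every rank -> identical permutation.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    indices = torch.randperm(num_samples, generator=g).tolist()

    # Each rank takes a disjoint, strided slice of the shared permutation,
    # so no sample is read twice across the group within one epoch.
    return indices[rank::world_size]


# e.g. with 10 samples, seed 42, epoch 0 and world_size 2,
# rank 0 and rank 1 receive complementary halves of the same shuffle.
print(deterministic_shuffle(10, 42, 0, rank=0, world_size=2))
print(deterministic_shuffle(10, 42, 0, rank=1, world_size=2))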
