Skip to content

Commit

Permalink
Removed cuda memory allocation which was causing CUDA OOM (#103)
Browse files Browse the repository at this point in the history
* Removed cuda memory allocation which was causing CUDA OOM

* changed parameter names and added random generator

* Add default shuffle_seed value

* Changed the shuffle_seed annotation type to just int
  • Loading branch information
karan6181 authored Dec 29, 2022
1 parent 5908579 commit b7dbadd
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 68 deletions.
61 changes: 18 additions & 43 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@
from enum import IntEnum
from multiprocessing.shared_memory import SharedMemory
from threading import Thread
from time import sleep, time
from time import sleep
from typing import Any, Dict, Iterator, Optional, Tuple

import numpy as np
import torch
from filelock import FileLock
from numpy.typing import NDArray
from torch import distributed as tdist
from torch.utils.data import IterableDataset

from streaming.base import distributed as dist
from streaming.base.compression import decompress
from streaming.base.format import reader_from_json
from streaming.base.format.base.reader import FileInfo
Expand All @@ -28,6 +25,7 @@
from streaming.base.shared import SharedBarrier
from streaming.base.shuffle import get_shuffle
from streaming.base.storage import download
from streaming.base.util import wait_for_file_to_exist
from streaming.base.world import World

# Time to wait, in seconds.
Expand Down Expand Up @@ -121,8 +119,7 @@ class StreamingDataset(IterableDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -139,7 +136,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
self.local = local
Expand All @@ -152,9 +149,11 @@ def __init__(self,
self.download_timeout = download_timeout
self.validate_hash = validate_hash or None

if tdist.is_available() and not tdist.is_initialized() and torch.cuda.is_available() and \
'RANK' in os.environ:
tdist.init_process_group('nccl')
if self.download_retry < 0:
raise ValueError('Parameter ``download_retry`` must be non-negative')
if self.download_timeout < 0:
raise ValueError(
'Parameter ``download_timeout`` (in seconds) must be greater than zero')

# Seed is set below.
world = World()
Expand All @@ -168,7 +167,11 @@ def __init__(self,
filename = self._download_file(basename)
else:
filename = os.path.join(local, self.split, basename) # pyright: ignore
dist.barrier()

# Everyone waits for the file to become populated.
wait_for_file_to_exist(filename, TICK, self.download_timeout,
f'{filename} file took too long to download')

obj = json.load(open(filename))
if obj['version'] != 2:
raise ValueError('Unsupported version')
Expand All @@ -184,27 +187,9 @@ def __init__(self,
self.index = Index(self.shard_sizes)

# Determine and distribute shuffle seed and shm prefix.
if shuffle_seed is None:
shuffle_seed = np.random.randint(1 << 60)
seed_rng = np.random.default_rng(shuffle_seed)
self.shuffle_seed = int(seed_rng.integers(1 << 60))
prefix_int = np.random.randint(1 << 24)
if world.num_ranks > 1:
# Setup for coordinating.
device_prefix = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(f'{device_prefix}:{world.rank_of_node}')
tensor = torch.zeros(1, dtype=torch.int64, device=device)

# Coordinate the shuffle seed across ranks.
if world.is_leader:
tensor[0] = shuffle_seed
dist.broadcast(tensor, 0)
shuffle_seed = int(tensor)

# Add a coordinated random prefix to all shm names for uniqueness.
if world.is_leader:
tensor[0] = prefix_int
dist.broadcast(tensor, 0)
prefix_int = int(tensor)
self.shuffle_seed = shuffle_seed
self._prefix = f'{prefix_int:06x}'

# Set up the epoch counter.
Expand Down Expand Up @@ -352,7 +337,7 @@ def _get_partition(self,
world: World,
epoch: int,
sample_in_epoch: int,
timeout: Optional[float] = 60) -> NDArray[np.int64]:
timeout: float = 60) -> NDArray[np.int64]:
"""Get this worker's partition of this epoch's sample space.
Args:
Expand Down Expand Up @@ -391,17 +376,7 @@ def _get_partition(self,
os.rename(tmp_filename, filename)

# Everyone waits for the file to become populated.
t0 = time()
while True:
sleep(TICK)
if os.path.exists(filename):
sleep(TICK)
break
if timeout is not None:
dt = time() - t0
if timeout < dt:
raise RuntimeError('Partitioning and shuffling took too long, bailing out: ' +
f'{timeout:.3f} < {dt:.3f} sec.')
wait_for_file_to_exist(filename, TICK, timeout, 'Partitioning and shuffling took too long')

# Each worker loads its slice of the sample ID tensor to iterate through.
# Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
Expand Down
26 changes: 26 additions & 0 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""Utility and helper functions for datasets."""

import os
from time import sleep, time
from typing import List

__all__ = ['get_list_arg']
Expand All @@ -18,3 +20,27 @@ def get_list_arg(text: str) -> List[str]:
List[str]: Splits, if any.
"""
return text.split(',') if text else []


def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float,
err_msg: str) -> None:
"""Wait for the file to exist till timeout seconds. Raise an Exception after that.
Args:
filename (str): A file name
poll_interval (float): Number of seconds to wait before next polling
timeout (float): Number of seconds to wait for a file to exist before raising an exception
err_msg (str): Error message description for an exception
Raises:
RuntimeError: Raise an Exception if file does not exist after timeout
"""
start_time = time()
while True:
sleep(poll_interval)
if os.path.exists(filename):
sleep(poll_interval)
break
dt = time() - start_time
if dt > timeout:
raise RuntimeError(f'{err_msg}, bailing out: ' + f'{timeout:.3f} < {dt:.3f} sec.')
5 changes: 2 additions & 3 deletions streaming/text/c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class StreamingC4(StreamingDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -62,7 +61,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
if group_method not in {'truncate', 'concat'}:
Expand Down
5 changes: 2 additions & 3 deletions streaming/text/enwiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class StreamingEnWiki(StreamingDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -51,7 +50,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
Expand Down
5 changes: 2 additions & 3 deletions streaming/text/pile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class StreamingPile(StreamingDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -62,7 +61,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None) -> None:
if group_method not in ['truncate']:
Expand Down
5 changes: 2 additions & 3 deletions streaming/vision/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class StreamingADE20K(StreamingDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -62,7 +61,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
Expand Down
10 changes: 4 additions & 6 deletions streaming/vision/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ class StreamingVisionDataset(StreamingDataset, VisionDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -94,7 +93,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
Expand Down Expand Up @@ -153,8 +152,7 @@ class StreamingImageClassDataset(StreamingVisionDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -173,7 +171,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
transforms = None
Expand Down
3 changes: 1 addition & 2 deletions streaming/vision/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class StreamingCIFAR10(StreamingImageClassDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand Down
5 changes: 2 additions & 3 deletions streaming/vision/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class StreamingCOCO(StreamingDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand All @@ -56,7 +55,7 @@ def __init__(self,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: Optional[int] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
Expand Down
3 changes: 1 addition & 2 deletions streaming/vision/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class StreamingImageNet(StreamingImageClassDataset):
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
shuffle_seed (int, optional): Seed for shuffling, or ``None`` for random seed. Defaults to
``None``.
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
Defaults to ``None``, which is interpreted as the number of nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
Expand Down

0 comments on commit b7dbadd

Please sign in to comment.