human-readable suffixes for size_limit and epoch_size (#333)
* human-readable suffixes for size_limit and epoch_size

* fixed default value bug: human-readable suffixes

* STR 34: added tests, fixed nits for human readable size args

* STR34: deleted unneeded comments, human readable suffixes

* STR34: fixed docstring human readable suffixes

* STR34: fixed pre-commit hooks for human readable suffixes

* STR34: pinning old identify package. human readable suffixes

* STR34: pinning old identify package. human readable suffixes

* fixed docstring for human-readable suffix explanation
snarayan21 authored Jul 24, 2023
1 parent e2d0431 commit 6b0d045
Showing 14 changed files with 186 additions and 52 deletions.
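In short: the writers' `size_limit` and `StreamingDataset`'s `epoch_size` now accept human-readable strings in addition to plain integers. A minimal usage sketch, assuming the multipliers described in the docstrings below (`"kb"` = 1024 bytes for sizes, `"k"` = 1,000 for sample counts); the output path and column schema are illustrative:

```python
from streaming import MDSWriter, StreamingDataset

# Illustrative schema and path.
columns = {'text': 'str'}

# size_limit: an int is a byte count; a suffixed string such as '64kb' is
# assumed to mean 64 * 1024 bytes, per the docstrings changed in this commit.
with MDSWriter(out='./my_dataset', columns=columns, size_limit='64kb') as writer:
    writer.write({'text': 'hello world'})

# epoch_size: an int sample count, or an abbreviation such as '100k'
# (assumed to mean 100,000 samples).
dataset = StreamingDataset(local='./my_dataset', epoch_size='100k')
```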
2 changes: 1 addition & 1 deletion docs/source/fundamentals/dataset_conversion_guide.md
@@ -59,7 +59,7 @@ _encodings['int32'] = Int32

4. A `hashes` algorithm name to verify data integrity. Check out the [hashing](hashing.md) document for additional details.

5. A shard `size_limit` in bytes for each shard file, after which point to start a new shard. Shard file size depends on the dataset size, but generally, too small of a shard size creates a ton of shard files and heavy network overheads, and too large of a shard size creates fewer shard files, but the training start time would increase since it has to wait for a shard file to get downloaded locally. Based on our intuition, the shard file size of 64Mb, and 128Mb play a balanced role.
5. A shard `size_limit` in bytes for each shard file, after which point a new shard starts. The right shard size depends on the dataset, but generally, too small a shard size creates a huge number of shard files and heavy network overhead, while too large a shard size creates fewer shard files at the cost of a longer training start time, since training has to wait for a shard file to finish downloading locally. In our experience, shard sizes of 64MB and 128MB strike a good balance. This parameter is a number of bytes, given either directly as an `int` or as a string with a human-readable suffix (e.g., `1024` or `"1kb"`).

6. A `keep_local` parameter if you would like to keep the shard files locally after it has been uploaded to a remote cloud location by MDSWriter.

2 changes: 1 addition & 1 deletion docs/source/getting_started/user_guide.md
@@ -80,7 +80,7 @@ hashes = ['sha1']
3. Provide a shard size limit, after which point to start a new shard.
<!--pytest-codeblocks:cont-->
```python
# Number act as a byte, e.g., 1024 bytes
# The number is interpreted as bytes, e.g., 1024 bytes. A string abbreviation (e.g., "1024b" or "1kb") is also accepted.
limit = 1024
```
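With this change, the same limit can also be written with a suffix; a small sketch (the 1024-byte multiplier for `"kb"` follows the docstrings added in this commit):

```python
# Human-readable equivalent of the integer above; '1kb' is parsed as 1 * 1024 bytes.
limit = '1kb'
```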

1 change: 1 addition & 0 deletions setup.py
@@ -77,6 +77,7 @@
'fastapi==0.100.0',
'pydantic==1.10.11',
'uvicorn==0.23.0',
'identify==2.5.25',
]

extra_deps['docs'] = [
20 changes: 14 additions & 6 deletions streaming/base/dataset.py
@@ -31,7 +31,7 @@
from streaming.base.shuffle import get_shuffle
from streaming.base.spanner import Spanner
from streaming.base.stream import Stream
from streaming.base.util import bytes_to_int
from streaming.base.util import bytes_to_int, number_abbrev_to_int
from streaming.base.world import World

# An arbitrary time in the future, used for cold shard eviction.
@@ -216,10 +216,11 @@ class StreamingDataset(Array, IterableDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
``False``.
epoch_size (int, optional): Number of samples to draw per epoch balanced across all
epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``.
            smaller epoch size. Also accepts human-readable number abbreviations
            (e.g., ``"100k"``, ``"64M"``, ``"77b"``). Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards per
number of workers provided in a dataloader while iterating. If ``None``, its value
gets derived using batch size and number of canonical nodes
@@ -261,7 +262,7 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
epoch_size: Optional[Union[int, str]] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
partition_algo: str = 'orig',
@@ -367,10 +368,17 @@ def __init__(self,
self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard
self.spanner = Spanner(self.samples_per_shard)

# Convert epoch size from string to int, if needed. Cannot be negative.
epoch_size_value = None
if epoch_size:
epoch_size_value = number_abbrev_to_int(epoch_size)
if epoch_size_value < 0:
raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.')

# Now that we know the number of underlying samples of each stream, derive each stream's
# true proportion/repeat/choose, as well as the total epoch size.
self.epoch_size = Stream.apply_weights(self.streams, self.samples_per_stream, epoch_size,
self.shuffle_seed)
self.epoch_size = Stream.apply_weights(self.streams, self.samples_per_stream,
epoch_size_value, self.shuffle_seed)

# Length (__len__) is the resampled epoch size divided over the number of devices.
self.length = ceil(self.epoch_size / world.num_ranks)
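The conversion above relies on `number_abbrev_to_int`, which lives in `streaming/base/util.py` and is not shown in this diff. A rough, hypothetical approximation of what it is assumed to do, inferred from the docstring examples (``"100k"``, ``"64M"``, ``"77b"``) and assuming decimal multipliers:

```python
# Hypothetical sketch only -- the real number_abbrev_to_int in
# streaming.base.util may differ in suffixes, casing, or error handling.
def approx_number_abbrev_to_int(value) -> int:
    if isinstance(value, int):
        return value
    multipliers = {'k': 10**3, 'm': 10**6, 'b': 10**9, 't': 10**12}
    text = str(value).strip().lower()
    if text and text[-1] in multipliers:
        return int(float(text[:-1]) * multipliers[text[-1]])
    return int(text)

assert approx_number_abbrev_to_int('100k') == 100_000
assert approx_number_abbrev_to_int(2048) == 2048
```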
27 changes: 19 additions & 8 deletions streaming/base/format/base/reader.py
@@ -6,9 +6,10 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Set
from typing import Any, Dict, Iterator, List, Optional, Set, Union

from streaming.base.array import Array
from streaming.base.util import bytes_to_int

__all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader']

@@ -36,8 +37,10 @@ class Reader(Array, ABC):
compression (str, optional): Optional compression or compression:level.
hashes (List[str]): Optional list of hash algorithms to apply to shard files.
samples (int): Number of samples in this shard.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes).
"""

def __init__(
@@ -47,8 +50,16 @@ def __init__(
compression: Optional[str],
hashes: List[str],
samples: int,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
) -> None:

if size_limit:
if (isinstance(size_limit, str)):
size_limit = bytes_to_int(size_limit)
if size_limit < 0:
raise ValueError(f'`size_limit` must be greater than zero, instead, ' +
f'found as {size_limit}.')

self.dirname = dirname
self.split = split or ''
self.compression = compression
@@ -277,7 +288,7 @@ class JointReader(Reader):
hashes (List[str]): Optional list of hash algorithms to apply to shard files.
raw_data (FileInfo): Uncompressed data file info.
samples (int): Number of samples in this shard.
size_limit (int, optional): Optional shard size limit, after which point to start a new
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
zip_data (FileInfo, optional): Compressed data file info.
"""
@@ -290,7 +301,7 @@ def __init__(
hashes: List[str],
raw_data: FileInfo,
samples: int,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
) -> None:
super().__init__(dirname, split, compression, hashes, samples, size_limit)
@@ -310,7 +321,7 @@ class SplitReader(Reader):
raw_data (FileInfo): Uncompressed data file info.
raw_meta (FileInfo): Uncompressed meta file info.
samples (int): Number of samples in this shard.
size_limit (int, optional): Optional shard size limit, after which point to start a new
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
zip_data (FileInfo, optional): Compressed data file info.
zip_meta (FileInfo, optional): Compressed meta file info.
@@ -325,7 +336,7 @@ def __init__(
raw_data: FileInfo,
raw_meta: FileInfo,
samples: int,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
zip_meta: Optional[FileInfo],
) -> None:
23 changes: 13 additions & 10 deletions streaming/base/format/base/writer.py
@@ -20,6 +20,7 @@
from streaming.base.format.index import get_index_basename
from streaming.base.hashing import get_hash, is_hash
from streaming.base.storage.upload import CloudUploader
from streaming.base.util import bytes_to_int

__all__ = ['JointWriter', 'SplitWriter']

@@ -44,8 +45,10 @@ class Writer(ABC):
``None``.
hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
Defaults to ``None``.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If ``None``, puts everything in one shard. Defaults to ``1 << 26``.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If ``None``, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes). Defaults to ``1 << 26``.
extra_bytes_per_shard (int): Extra bytes per serialized shard (for computing shard size
while writing). Defaults to ``0``.
extra_bytes_per_sample (int): Extra bytes per serialized sample (for computing shard size
@@ -67,7 +70,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
extra_bytes_per_shard: int = 0,
extra_bytes_per_sample: int = 0,
**kwargs: Any) -> None:
@@ -84,17 +87,17 @@ def __init__(self,
if not is_hash(algo):
raise ValueError(f'Invalid hash: {algo}.')

+        size_limit_value = None
         if size_limit:
-            if size_limit < 0:
+            size_limit_value = bytes_to_int(size_limit)
+            if size_limit_value < 0:
                 raise ValueError(f'`size_limit` must be greater than zero, instead, ' +
-                                 f'found as {size_limit}.')
-        else:
-            size_limit = None
+                                 f'found as {size_limit_value}.')

self.keep_local = keep_local
self.compression = compression
self.hashes = hashes
self.size_limit = size_limit
self.size_limit = size_limit_value
self.extra_bytes_per_shard = extra_bytes_per_shard
self.extra_bytes_per_sample = extra_bytes_per_sample
self.new_samples: List[bytes]
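On the writer side, a string `size_limit` is routed through `bytes_to_int` before the negative-value check. That helper also lives in `streaming/base/util.py` and is not part of this diff; a hedged approximation using the 1024-based multipliers the docstrings describe (``"100kb"`` -> 100 * 1024 bytes):

```python
# Illustrative approximation of bytes_to_int; the real helper may accept more
# formats. Per this commit's docstrings, '100kb' -> 100 * 1024 bytes.
def approx_bytes_to_int(value) -> int:
    if isinstance(value, int):
        return value
    units = {'b': 1, 'kb': 1024, 'mb': 1024**2, 'gb': 1024**3, 'tb': 1024**4}
    text = str(value).strip().lower()
    for suffix in sorted(units, key=len, reverse=True):
        if text.endswith(suffix):
            return int(float(text[:-len(suffix)]) * units[suffix])
    return int(text)

assert approx_bytes_to_int('100kb') == 102400
assert approx_bytes_to_int(1 << 26) == 67108864
```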
@@ -338,7 +341,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
extra_bytes_per_shard: int = 0,
extra_bytes_per_sample: int = 0,
**kwargs: Any) -> None:
@@ -419,7 +422,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
**kwargs: Any) -> None:
super().__init__(out=out,
keep_local=keep_local,
8 changes: 5 additions & 3 deletions streaming/base/format/json/reader.py
@@ -6,7 +6,7 @@
import json
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
from typing_extensions import Self
@@ -31,7 +31,9 @@ class JSONReader(SplitReader):
raw_meta (FileInfo): Uncompressed meta file info.
samples (int): Number of samples in this shard.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes).
zip_data (FileInfo, optional): Compressed data file info.
zip_meta (FileInfo, optional): Compressed meta file info.
"""
@@ -47,7 +49,7 @@ def __init__(
raw_data: FileInfo,
raw_meta: FileInfo,
samples: int,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
zip_meta: Optional[FileInfo],
) -> None:
8 changes: 5 additions & 3 deletions streaming/base/format/json/writer.py
@@ -34,8 +34,10 @@ class JSONWriter(SplitWriter):
``None``.
hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
Defaults to ``None``.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard. Defaults to ``None``.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes). Defaults to ``1 << 26``.
**kwargs (Any): Additional settings for the Writer.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
@@ -55,7 +57,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
**kwargs: Any) -> None:
super().__init__(out=out,
keep_local=keep_local,
10 changes: 6 additions & 4 deletions streaming/base/format/mds/reader.py
@@ -5,7 +5,7 @@

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
from typing_extensions import Self
@@ -29,8 +29,10 @@ class MDSReader(JointReader):
hashes (List[str]): Optional list of hash algorithms to apply to shard files.
raw_data (FileInfo): Uncompressed data file info.
samples (int): Number of samples in this shard.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes).
zip_data (FileInfo, optional): Compressed data file info.
"""

@@ -45,7 +47,7 @@ def __init__(
hashes: List[str],
raw_data: FileInfo,
samples: int,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
) -> None:
super().__init__(dirname, split, compression, hashes, raw_data, samples, size_limit,
8 changes: 5 additions & 3 deletions streaming/base/format/mds/writer.py
@@ -34,8 +34,10 @@ class MDSWriter(JointWriter):
``None``.
hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
Defaults to ``None``.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If ``None``, puts everything in one shard. Defaults to ``1 << 26``.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If ``None``, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes). Defaults to ``1 << 26``.
**kwargs (Any): Additional settings for the Writer.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
@@ -55,7 +57,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
**kwargs: Any) -> None:
super().__init__(out=out,
keep_local=keep_local,
10 changes: 6 additions & 4 deletions streaming/base/format/xsv/reader.py
@@ -5,7 +5,7 @@

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
from typing_extensions import Self
@@ -31,8 +31,10 @@ class XSVReader(SplitReader):
raw_meta (FileInfo): Uncompressed meta file info.
samples (int): Number of samples in this shard.
separator (str): Separator character(s).
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes).
zip_data (FileInfo, optional): Compressed data file info.
zip_meta (FileInfo, optional): Compressed meta file info.
"""
@@ -50,7 +52,7 @@ def __init__(
raw_meta: FileInfo,
samples: int,
separator: str,
size_limit: Optional[int],
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
zip_meta: Optional[FileInfo],
) -> None:
8 changes: 5 additions & 3 deletions streaming/base/format/xsv/writer.py
@@ -35,8 +35,10 @@ class XSVWriter(SplitWriter):
``None``.
hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
Defaults to ``None``.
size_limit (int, optional): Optional shard size limit, after which point to start a new
shard. If None, puts everything in one shard. Defaults to ``None``.
size_limit (Union[int, str], optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Can also be given in a
            human-readable format, for example ``"100kb"`` for 100 kilobytes
            (100 * 1024 bytes). Defaults to ``1 << 26``.
**kwargs (Any): Additional settings for the Writer.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
@@ -57,7 +59,7 @@ def __init__(self,
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[List[str]] = None,
size_limit: Optional[int] = 1 << 26,
size_limit: Optional[Union[int, str]] = 1 << 26,
**kwargs: Any) -> None:
super().__init__(out=out,
keep_local=keep_local,