Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nuke concat option. #129

Merged
merged 1 commit into from
Jan 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 3 additions & 56 deletions streaming/text/c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
the `Common Crawl <https://commoncrawl.org>`_ dataset.
"""

from typing import Any, Dict, Iterator, Optional
from typing import Any, Dict, Optional

from transformers.models.auto.tokenization_auto import AutoTokenizer

Expand Down Expand Up @@ -64,9 +64,8 @@ def __init__(self,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None):
if group_method not in {'truncate', 'concat'}:
raise ValueError(
f"group_method='{group_method}' must be one of ['truncate', 'concat'].")
if group_method not in {'truncate'}:
raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.")

super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
download_timeout, validate_hash, shuffle_seed, num_canonical_nodes,
Expand All @@ -80,8 +79,6 @@ def __init__(self,
if self.tokenizer.pad_token is None:
# Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
self.tokenizer.pad_token = self.tokenizer.eos_token
# suppress warnings when using group_method='concat' and no truncation
self.tokenizer.model_max_length = int(1e30)

def _tokenize(self, text_sample: Dict[str, Any]):
"""Apply the tokenizer to a sample.
Expand All @@ -93,10 +90,6 @@ def _tokenize(self, text_sample: Dict[str, Any]):
truncation = True
padding = 'max_length'
max_length = self.max_seq_len
elif self.group_method == 'concat':
truncation = False
padding = False
max_length = None
else:
raise ValueError(f'Got unknown group_method={self.group_method}.')
return self.tokenizer(text_sample['text'],
Expand All @@ -117,49 +110,3 @@ def __getitem__(self, idx: int) -> Any:
token_sample = self._tokenize(text_sample)
# Skip any token grouping, currently only supporting group_method='truncate'
return token_sample

def __iter__(self) -> Iterator[Any]:
    """Iterate over token samples, optionally packing them to a fixed length.

    The behavior depends on ``self.group_method``:

    * ``'truncate'``: samples from the parent iterator are yielded unchanged.
    * ``'concat'``: the values of successive samples are appended per key into a
      running buffer. Whenever the buffered ``'input_ids'`` hold at least
      ``self.max_seq_len`` items, the first ``self.max_seq_len`` items of every
      key are yielded as one sample and the rest stays buffered. When the parent
      iterator is exhausted it is restarted, so this mode never terminates on
      its own.

    Yields:
        Iterator[Any]: Sample iterator

    Raises:
        ValueError: If ``self.group_method`` is neither ``'truncate'`` nor
            ``'concat'`` (raised when iteration starts).
    """
    if self.group_method == 'truncate':
        yield from super().__iter__()
        return

    if self.group_method != 'concat':
        raise ValueError(f'Got unknown group_method={self.group_method}.')

    # Leftover tokens survive across passes over the parent iterator.
    pending = {}
    while True:
        for sample in super().__iter__():
            for key, values in sample.items():
                pending[key] = pending.get(key, []) + values
            # Emit as many full-length chunks as the buffer currently allows.
            while len(pending['input_ids']) >= self.max_seq_len:
                chunk = {key: values[:self.max_seq_len] for key, values in pending.items()}
                pending = {key: values[self.max_seq_len:] for key, values in pending.items()}
                yield chunk

def __len__(self) -> Optional[int]:
    """Report how many samples the dataset holds.

    With ``group_method='truncate'`` the length reported by the parent class is
    returned. With ``group_method='concat'`` iteration repeats forever, so there
    is no defined length and ``None`` is returned.

    NOTE: the built-in ``len()`` requires an integer result, so calling
    ``len(dataset)`` in ``'concat'`` mode raises ``TypeError``; only a direct
    ``dataset.__len__()`` call observes the ``None``.

    Returns:
        Optional[int]: Number of samples, or ``None`` when it is undefined.

    Raises:
        ValueError: If ``self.group_method`` is neither ``'truncate'`` nor
            ``'concat'``.
    """
    if self.group_method == 'concat':
        return None
    if self.group_method == 'truncate':
        return super().__len__()
    raise ValueError(f'Got unknown group_method={self.group_method}.')