From 236aadc62b7613e62109aa2d860bd74bbb4b0309 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 23 Jan 2023 09:11:14 -0800 Subject: [PATCH] Nuke concat option. --- streaming/text/c4.py | 59 +++----------------------------------------- 1 file changed, 3 insertions(+), 56 deletions(-) diff --git a/streaming/text/c4.py b/streaming/text/c4.py index 733831f7b..6a3ca86db 100644 --- a/streaming/text/c4.py +++ b/streaming/text/c4.py @@ -7,7 +7,7 @@ the `Common Crawl `_ dataset. """ -from typing import Any, Dict, Iterator, Optional +from typing import Any, Dict, Optional from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -64,9 +64,8 @@ def __init__(self, shuffle_seed: int = 9176, num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None): - if group_method not in {'truncate', 'concat'}: - raise ValueError( - f"group_method='{group_method}' must be one of ['truncate', 'concat'].") + if group_method not in {'truncate'}: + raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.") super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry, download_timeout, validate_hash, shuffle_seed, num_canonical_nodes, @@ -80,8 +79,6 @@ def __init__(self, if self.tokenizer.pad_token is None: # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs self.tokenizer.pad_token = self.tokenizer.eos_token - # suppress warnings when using group_method='concat' and no truncation - self.tokenizer.model_max_length = int(1e30) def _tokenize(self, text_sample: Dict[str, Any]): """Apply the tokenizer to a sample. @@ -93,10 +90,6 @@ def _tokenize(self, text_sample: Dict[str, Any]): truncation = True padding = 'max_length' max_length = self.max_seq_len - elif self.group_method == 'concat': - truncation = False - padding = False - max_length = None else: raise ValueError(f'Got unknown group_method={self.group_method}.') return self.tokenizer(text_sample['text'], @@ -117,49 +110,3 @@ def __getitem__(self, idx: int) -> Any: token_sample = self._tokenize(text_sample) # Skip any token grouping, currently only supporting group_method='truncate' return token_sample - - def __iter__(self) -> Iterator[Any]: - """Iterable over samples. - - Since concatenating samples has a custom behavior, it requires extending the - parent iterator class. - - For `group_method = truncate`, simply return the token sample. - For `group_method = concat`, keep fetching token samples until it fills up the max_seq_len. - - Yields: - Iterator[Any]: Sample iterator - """ - if self.group_method == 'truncate': - yield from super().__iter__() - elif self.group_method == 'concat': - buffer = {} - while True: - iterator = super().__iter__() - for sample in iterator: - for k, v in sample.items(): - buffer[k] = buffer.get(k, []) + v - while len(buffer['input_ids']) >= self.max_seq_len: - concat_sample = {} - for k, v in buffer.items(): - concat_sample[k] = v[:self.max_seq_len] - buffer[k] = v[self.max_seq_len:] - yield concat_sample - else: - raise ValueError(f'Got unknown group_method={self.group_method}.') - - def __len__(self) -> Optional[int]: - """Number of samples in a dataset. - - For `group_method = truncate`, return the number of samples. - For `group_method = concat`, since it repeat forever, it doesn't have any defined length. - - Returns: - Optional[int]: Number of samples - """ - if self.group_method == 'truncate': - return super().__len__() - elif self.group_method == 'concat': - return None - else: - raise ValueError(f'Got unknown group_method={self.group_method}.')