support sft mapdataset #8840
Merged
wawltor merged 5 commits into PaddlePaddle:develop from greycooker:support_sft_mapdataset on Aug 5, 2024
Changes from 2 commits
@@ -29,6 +29,7 @@
import shutil
import struct
import time
from dataclasses import fields
from functools import lru_cache
from itertools import accumulate
@@ -68,6 +69,20 @@
    return None


def make_sft_dataset(path, impl, dataclass, skip_warmup=False):
    if impl == "mmap" and SFT_MMapIndexedDataset.exists(path, dataclass):
        print_rank_0(" > building dataset index ...")
        start_time = time.time()
        sft_indexed_dataset = SFT_MMapIndexedDataset(path, dataclass, skip_warmup)
        print_rank_0(" > finished creating SFT indexed dataset in {:4f} seconds".format(time.time() - start_time))
        print_rank_0(" number of samples: {}".format(len(sft_indexed_dataset.doc_idx) - 1))

        return sft_indexed_dataset

    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
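For context, a hedged usage sketch of the new entry point; the path and the MySFTExample dataclass below are illustrative placeholders, not names from this PR:

    dataset = make_sft_dataset("/data/sft_chat", impl="mmap", dataclass=MySFTExample)
    print(len(dataset))  # number of documents, i.e. len(doc_idx) - 1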
@@ -120,6 +135,18 @@
    return prefix_path + ".idx"


def sft_index_file_path(prefix_path):
    return os.path.join(prefix_path, "index.idx")


def sft_data_file_path(prefix_path, dataclass):
    file_path_list = []
    for field in fields(dataclass):
        file_path = os.path.join(prefix_path, f"{field.name}.bin")
        file_path_list.append(file_path)
    return file_path_list


def data_file_path(prefix_path):
    return prefix_path + ".bin"
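Taken together, these helpers mean an SFT dataset is a directory rather than a single .bin/.idx pair. For a hypothetical dataclass with fields input_ids, labels and num_turns, the expected layout would be:

    dataset_dir/
        index.idx        # shared index: sizes, pointers, doc_idx
        input_ids.bin    # one flat binary file per dataclass field
        labels.bin
        num_turns.bin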
@@ -548,13 +575,259 @@
    return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class SFT_MMapIndexedDataset(paddle.io.Dataset):

Review comment: Use CamelCase for the class name here; do not use underscores.

    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"
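        # On-disk layout of index.idx, summarized from the writer below
        # (comment added for review readability):
        #   bytes 0-8    magic b"MMIDIDX\x00\x00"
        #   bytes 9-16   format version, little-endian uint64 (currently 1)
        #   byte 17      dtype code
        #   bytes 18-25  number of sizes, uint64
        #   bytes 26-33  number of doc_idx entries, uint64
        #   followed by the sizes (int32), pointers (int64) and doc_idx (int64) arrays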

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []
                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size
                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)
                    self._file.write(struct.pack("<Q", len(sizes)))
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()
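        # Note (added for readability): writer() is consumed as a context
        # manager; __enter__ writes the header fields above and __exit__ closes
        # the file. See SFT_MMapIndexedDatasetBuilder.finalize below for the
        # call site.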

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._buffer_mmap = np.memmap(path, mode="r", order="C")
            self._buffer = memoryview(self._buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(self._buffer, dtype=np.int32, count=self._len, offset=offset)
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(
                self._buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
            )
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._buffer_mmap._mmap.close()
            del self._buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            # doc_idx stores one boundary per document plus a trailing sentinel,
            # so the number of documents is doc_count - 1
            return self._doc_count - 1

    def __init__(self, path, dataclass, skip_warmup=False):
        super().__init__()
        self._dataclass = dataclass
        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        # only the path is pickled; note that _dataclass is not part of this
        # state and is therefore not restored by __setstate__
        return self._path

    def __setstate__(self, state):
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        if not self.exists(path, self._dataclass):
            raise ValueError("Missing file, %s" % (path))

        self._index = self.Index(sft_index_file_path(self._path), skip_warmup)
        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            for data_file in sft_data_file_path(self._path, self._dataclass):
                _warmup_mmap_file(data_file)
        print_rank_0(" creating numpy buffer of mmap...")

        self._bin_buffer_mmap_dict = {}
        self._bin_buffer_dict = {}
        for data_file in sft_data_file_path(self._path, self._dataclass):
            self._bin_buffer_mmap_dict[data_file] = np.memmap(data_file, mode="r", order="C")
            self._bin_buffer_dict[data_file] = memoryview(self._bin_buffer_mmap_dict[data_file])
        print_rank_0(" creating memory view of numpy buffer...")

    def __del__(self):
        for key, value in self._bin_buffer_mmap_dict.items():
            value._mmap.close()
        for key, value in self._bin_buffer_dict.items():
            del value
        del self._index

    def __len__(self):
        return len(self._index)

    def __getitem__(self, idx):
        def get_index(idx):
            doc_idx = self._index.doc_idx
            start_sentence, end_sentence = doc_idx[idx], doc_idx[idx + 1]
            start_pointers, _ = self._index[start_sentence]
            length_list = self._index._sizes[start_sentence:end_sentence]

            dataclass_fields = fields(self._dataclass)
            dataclass_list = []
            # sequence fields are read from the byte pointer of the document's
            # first sentence; int fields store one scalar per sentence, so their
            # offset is the sentence index times the item size
            sequence_offset = start_pointers
            scalar_offset = doc_idx[idx] * np.dtype(self._index.dtype).itemsize

            for length in length_list:
                field_data = {field.name: [] for field in dataclass_fields}
                for field in dataclass_fields:
                    bin_buffer = self._bin_buffer_dict[os.path.join(self._path, f"{field.name}.bin")]
                    if field.type != int:
                        data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=length, offset=sequence_offset)
                        field_data[field.name] = data.tolist()
                    else:
                        data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=1, offset=scalar_offset)
                        field_data[field.name] = int(data[0])

                dataclass_list.append(self._dataclass(**field_data))

                sequence_offset += length * np.dtype(self._index.dtype).itemsize
                scalar_offset += np.dtype(self._index.dtype).itemsize
            return dataclass_list

        if isinstance(idx, (int, np.integer)):
            return get_index(idx)
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            return [get_index(idx) for idx in range(start, stop)]

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path, dataclass):
        file_path_list = sft_data_file_path(path, dataclass)
        file_path_list.append(sft_index_file_path(path))
        for file_path in file_path_list:
            if not os.path.exists(file_path):
                return False
        return True
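For review context, a minimal read-path sketch. The dataclass and directory below are illustrative assumptions, not part of this PR; the diff only requires that non-int fields decode as token sequences and int fields as one scalar per sentence:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Sequence:              # hypothetical example dataclass
        input_ids: List[int]     # non-int annotation: decoded as a sequence
        labels: List[int]
        num_turns: int           # int annotation: decoded as one scalar per sentence

    # assumes dataset_dir/ holds input_ids.bin, labels.bin, num_turns.bin, index.idx
    ds = make_sft_dataset("dataset_dir", impl="mmap", dataclass=Sequence)
    doc = ds[0]      # a list of Sequence objects, one per sentence in document 0
    docs = ds[0:2]   # contiguous slices are supported; step != 1 raises ValueError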

def make_builder(out_file, impl, save_dtype, loss_mask_file=None):
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file)
    else:
        return IndexedDatasetBuilder(out_file, dtype=save_dtype)

class SFT_MMapIndexedDatasetBuilder(object):

Review comment: Same naming convention issue here.

    def __init__(self, output_file_dict, dtype):
        self._data_file_dict = {}
        for key, filename in output_file_dict.items():
            self._data_file_dict[key] = open(filename, "wb")
        self.output_file_dict = output_file_dict
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, sequence):
        add_sequence_len = False
        for key in self._data_file_dict.keys():
            tensor = np.array(getattr(sequence, key), dtype=self._dtype)
            # record the sentence length once, from the first multi-element field
            if tensor.size > 1 and not add_sequence_len:
                self._sizes.append(tensor.size)
                add_sequence_len = True
            self._data_file_dict[key].write(tensor.tobytes(order="C"))

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def finalize(self, index_file):
        for key, filename in self._data_file_dict.items():
            filename.close()
        with SFT_MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
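And the matching write path, under the same illustrative assumptions; the keys of output_file_dict must match the dataclass field names, since add_item reads each field with getattr:

    import numpy as np

    output_files = {
        "input_ids": "dataset_dir/input_ids.bin",
        "labels": "dataset_dir/labels.bin",
        "num_turns": "dataset_dir/num_turns.bin",
    }
    builder = SFT_MMapIndexedDatasetBuilder(output_files, dtype=np.int32)
    for document in documents:           # each document: a list of dataclass instances
        for sentence in document:
            builder.add_item(sentence)   # appends one record to each field's .bin file
        builder.end_document()           # records the document boundary in doc_idx
    builder.finalize("dataset_dir/index.idx")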

class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype, loss_mask_file=None):
        self._data_file = open(out_file, "wb")
Review discussion on make_sft_dataset:

Reviewer: Alternatively, just support mmap only and drop the impl check altogether.

Author: Changed; anything other than mmap now raises an error directly.

Reviewer: Suggestion: make_sft_dataset(path, dataclass, skip_warmup=False, impl="mmap")
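Based on this thread, a hedged sketch of the revised entry point; the merged commits may differ in details such as the exact exception type:

    def make_sft_dataset(path, dataclass, skip_warmup=False, impl="mmap"):
        if impl != "mmap":
            raise ValueError(f"Unknown dataset implementation: {impl}")  # non-mmap errors out
        return SFT_MMapIndexedDataset(path, dataclass, skip_warmup)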