Feature/mixture #86

Merged
merged 5 commits on Nov 21, 2023
60 changes: 50 additions & 10 deletions data_juicer/format/mixture_formatter.py
@@ -1,3 +1,4 @@
from itertools import chain, repeat
from typing import List, Tuple, Union

import numpy as np
@@ -17,6 +18,7 @@ def __init__(self,
suffixes: Union[str, List[str], Tuple[str]] = None,
text_keys=None,
add_suffix=False,
max_samples=None,
**kwargs):
"""
Initialization method.
@@ -28,9 +30,30 @@ def __init__(self,
:param text_keys: key names of field that stores sample text.
:param add_suffix: whether to add the file suffix to dataset
meta info
:param max_samples: maximum number of samples in the mixed dataset.
:param kwargs: extra args
"""

data_prefixes, weights = self._get_weight(data_prefix=dataset_path)
sample_numbers = [0] * len(weights)
if max_samples is not None:
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
sample_num_per_dataset = [
int(np.ceil(max_samples * weight)) for weight in weights
]

# Cap each per-dataset count so the total never exceeds max_samples.
acc_sample_numbers = 0
for i in range(len(sample_num_per_dataset)):
sample_numbers[i] = min(sample_num_per_dataset[i],
max_samples - acc_sample_numbers)
acc_sample_numbers += sample_numbers[i]

self.sample_numbers = sample_numbers
self.weights = weights
self.formatters = [
load_formatter(dataset_path=data_prefix,
@@ -54,7 +77,7 @@ def _get_weight(self, data_prefix):

for i in range(len(data_prefix)):
try:
value = float(data_prefix[i])
value = max(float(data_prefix[i]), 0.0)
weights.append(value)
except: # noqa: E722
value = data_prefix[i].strip()
@@ -65,21 +88,36 @@
prefixes.append(value)
return prefixes, weights

def _random_sample(self, dataset, weight=1.0, seed=None):
def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None):
"""
Randomly sample a subset from a dataset with weight.
Randomly sample a subset from a dataset by weight or by sample
number; if sample_number is greater than 0, it takes precedence
over weight.
:param dataset: a HuggingFace dataset
:param weight: sampling ratio of the dataset
:param sample_number: number of samples to draw from the dataset
:param seed: random sampling seed; defaults to 42 if None
:return: a subset of dataset
"""
if seed is None:
seed = 42
num_samples = min(int(np.ceil(dataset.num_rows * weight)),
dataset.num_rows)
if num_samples == dataset.num_rows:

ds_samples = dataset.num_rows
if sample_number <= 0:
sample_number = int(np.ceil(ds_samples * weight))

if sample_number == ds_samples:
return dataset
return dataset.shuffle(seed=seed).select(range(num_samples))

sample_index = range(sample_number)

n_repeat = int(np.ceil(sample_number / ds_samples)) - 1
if n_repeat > 0:
remain_samples = sample_number - n_repeat * ds_samples
sample_index = chain(*repeat(range(ds_samples), n_repeat),
range(remain_samples))

return dataset.shuffle(seed=seed).select(sample_index)

def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
"""
Expand All @@ -90,11 +128,13 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
:return: mixed dataset
"""
dataset_list = []
for weight, formatter in zip(self.weights, self.formatters):
for weight, sample_num, formatter in zip(self.weights,
self.sample_numbers,
self.formatters):
dataset = formatter.load_dataset(num_proc, global_cfg)
sampled = self._random_sample(dataset, weight)
sampled = self._random_sample(dataset, weight, sample_num)
logger.info(f'sampled {len(sampled)} from '
f'{len(dataset)} with weight {weight}')
f'{len(dataset)}')
dataset_list.append(sampled)

from data_juicer.core.data import NestedDataset
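To make the new `max_samples` arithmetic concrete, here is a standalone sketch of the allocation step in `__init__`: weights are normalized, each dataset requests ceil(max_samples * weight) rows, and a running cap keeps the total at max_samples. The helper name `allocate_samples` is illustrative, not part of this PR.

```python
# Standalone sketch of the max_samples allocation above; the helper
# name `allocate_samples` is illustrative, not part of this PR.
import numpy as np

def allocate_samples(weights, max_samples):
    weights = np.array(weights, dtype=np.float64)
    assert weights.sum() > 0.0
    weights /= weights.sum()  # normalize weights to sum to 1.0
    # Each dataset requests ceil(max_samples * weight) samples ...
    requested = [int(np.ceil(max_samples * w)) for w in weights]
    # ... but a running cap keeps the total at max_samples.
    allocated, acc = [], 0
    for n in requested:
        take = min(n, max_samples - acc)
        allocated.append(take)
        acc += take
    return allocated

print(allocate_samples([0.5, 0.5], 7))  # [4, 3]
```

For the tests' `'0.5 ' + file + ' 0.5 ' + file2` mixture with max_samples=7, this yields [4, 3], which is why test_multi_datasets_with_sample expects exactly 7 rows.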
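The other behavioral change is oversampling in `_random_sample`: when the requested count exceeds the dataset size, the index is built by repeating the full range and topping it up with the remainder, so a mixture can draw more rows than a source dataset holds. A minimal sketch of that index construction (`build_sample_index` is an illustrative name, not from this PR):

```python
# Sketch of the chain/repeat index construction used in _random_sample;
# `build_sample_index` is an illustrative name, not part of this PR.
from itertools import chain, repeat

import numpy as np

def build_sample_index(ds_samples, sample_number):
    n_repeat = int(np.ceil(sample_number / ds_samples)) - 1
    if n_repeat > 0:
        # Repeat the full range n_repeat times, then add the remainder.
        remain = sample_number - n_repeat * ds_samples
        return list(chain(*repeat(range(ds_samples), n_repeat),
                          range(remain)))
    return list(range(sample_number))

print(build_sample_index(2, 5))  # [0, 1, 0, 1, 0]
print(build_sample_index(6, 3))  # [0, 1, 2]
```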
4 changes: 4 additions & 0 deletions tests/format/data/structured/demo-dataset.jsonl
@@ -1,2 +1,6 @@
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
78 changes: 78 additions & 0 deletions tests/format/test_mixture_formatter.py
@@ -0,0 +1,78 @@
import os
import unittest

from data_juicer.format.mixture_formatter import MixtureFormatter


class MixtureFormatterTest(unittest.TestCase):

def setUp(self):
self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'data', 'structured')
self._file = os.path.join(self._path, 'demo-dataset.jsonl')
self._file2 = self._file

def test_only_file(self):
formatter = MixtureFormatter(self._file)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 6)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_sample_weight(self):
formatter = MixtureFormatter('0.5 ' + self._file)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 3)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_sample_number(self):
max_samples = 2
formatter = MixtureFormatter(self._file, max_samples=max_samples)
ds = formatter.load_dataset()
self.assertEqual(len(ds), max_samples)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_sample_number_weight(self):
max_samples = 2
formatter = MixtureFormatter('0.5 ' + self._file, max_samples=max_samples)
ds = formatter.load_dataset()
self.assertEqual(len(ds), max_samples)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_multi_datasets_without_weight(self):
data_path = self._file + ' ' + self._file2
formatter = MixtureFormatter(data_path)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 12)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_multi_datasets_with_default_weight(self):
data_path = self._file + ' ' + self._file2
formatter = MixtureFormatter(data_path)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 12)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_multi_datasets_with_one_weight(self):
data_path = '0.5 ' + self._file + ' ' + self._file2
formatter = MixtureFormatter(data_path)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 9)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_multi_datasets_with_weight(self):
data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2
formatter = MixtureFormatter(data_path)
ds = formatter.load_dataset()
self.assertEqual(len(ds), 6)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

def test_multi_datasets_with_sample(self):
max_samples = 7
data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2
formatter = MixtureFormatter(data_path, max_samples=max_samples)
ds = formatter.load_dataset()
self.assertEqual(len(ds), max_samples)
self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

if __name__ == '__main__':
unittest.main()
3 changes: 1 addition & 2 deletions tests/format/test_unify_format.py
@@ -366,8 +366,7 @@ def test_hetero_meta(self):
'author': 'xxx'
}
}]
unified_sample_list = ds.to_list()
self.assertEqual(unified_sample_list, sample)

# test nested and missing field for the following cases:
# 1. first row, then column
unified_sample_first = ds[0]
7 changes: 6 additions & 1 deletion tools/postprocess/data_mixture.py
@@ -33,6 +33,11 @@ def parse_args():
'size of each shard won\'t be larger than the '
'export_shard_size')

parser.add_argument('--max_samples',
type=int,
default=None,
help='Maximum number of samples in the mixed dataset.')

parser.add_argument('--num_proc',
type=int,
default=4,
@@ -58,7 +63,7 @@ def run_mixture():
"""
args = parse_args()
data_path = ' '.join(args.data_path)
formatter = load_formatter(data_path)
formatter = load_formatter(data_path, max_samples=args.max_samples)
dataset = formatter.load_dataset(args.num_proc)
exporter = Exporter(export_path=args.export_path,
export_shard_size=args.export_shard_size,
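With the new `--max_samples` flag wired through `load_formatter`, the programmatic equivalent of the tool is roughly the sketch below; the `load_formatter` import path and the dataset paths are assumptions for illustration, not taken from this PR.

```python
# Sketch only: the import path and file names are assumed.
from data_juicer.format.load import load_formatter

# Mix two datasets at equal weight, capping the result at 1000 samples.
formatter = load_formatter('0.5 ds1.jsonl 0.5 ds2.jsonl', max_samples=1000)
dataset = formatter.load_dataset(num_proc=4)
print(len(dataset))  # at most 1000
```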