-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataset.py
124 lines (98 loc) · 4.77 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import json
from torch.utils.data import Dataset
def load_data(path):
    """Read a newline-delimited dataset file and return its lines.

    Each line is expected to hold one JSON-encoded training example.

    Args:
        path: Path of the text file to read.

    Returns:
        list[str]: The file's lines, with leading/trailing blank lines
        removed (the file content is stripped as a whole before splitting).
    """
    # encoding is pinned to utf-8: the prompts contain Chinese text and the
    # platform default encoding (e.g. cp1252 on Windows) would mangle them.
    with open(path, "r", encoding="utf-8") as fp:
        data = fp.read().strip().split("\n")
    return data
def print_dataset_example(input_input_ids, label_input_ids, tokenizer):
    """Print one encoded example (raw ids, tokens, decoded text) to stdout.

    Debug helper for eyeballing what the tokenizer produced for the input
    side and the label side of a single example.
    """
    for prefix, ids in (("input", input_input_ids), ("label", label_input_ids)):
        print(prefix + "_ids", ids)
        print(prefix + "_tokens", tokenizer.convert_ids_to_tokens(ids))
        print(prefix + "s", tokenizer.decode(ids))
class NerCollate:
    """Collator that turns instruction/query/answer examples into padded
    ``input_ids``/``labels`` tensors for ChatGLM-style fine-tuning.

    Examples may arrive as JSON strings or already-parsed dicts; ones with
    an empty query or answer are silently dropped.
    """

    def __init__(self, args, tokenizer):
        """Cache the tokenizer plus length limits and column names from ``args``."""
        self.tokenizer = tokenizer
        self.max_source_length = args.max_source_length
        self.max_target_length = args.max_target_length
        # Fixed padded length of every sequence in a batch.
        self.max_seq_length = args.max_source_length + args.max_target_length
        self.instruct_column = args.instruct_column
        self.query_column = args.query_column
        self.response_column = args.response_column
        self.ignore_pad_token_for_loss = args.ignore_pad_token_for_loss
        # Multi-turn chat history is not used for NER data.
        self.history_column = None

    def _build_prompt(self, example):
        # Assemble the source-side prompt text for one example.
        instruct = example[self.instruct_column]
        query = example[self.query_column]
        if self.history_column is None:
            return instruct + "\n" + query
        # Multi-turn format (unused here: history_column is always None,
        # but kept for parity with the ChatGLM prompt convention).
        rounds = []
        history = example[self.history_column]
        for turn_idx, (old_query, response) in enumerate(history):
            rounds.append("[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response))
        rounds.append("[Round {}]\n问:{}\n答:".format(len(history), query))
        return "".join(rounds)

    def _encode_example(self, example):
        # Encode a single example into (input_ids, labels), both padded to
        # self.max_seq_length.
        tokenizer = self.tokenizer
        answer = example[self.response_column]
        a_ids = tokenizer.encode(text=self._build_prompt(example), add_special_tokens=False)
        b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
        # Truncate, leaving room for the special tokens added below
        # (presumably gmask/sop/eos for ChatGLM — depends on the tokenizer).
        a_ids = a_ids[: self.max_source_length - 1]
        b_ids = b_ids[: self.max_target_length - 2]
        input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
        # Everything up to (but not including) the bos/sop token is prompt
        # context and is masked out of the loss with -100.
        context_length = input_ids.index(tokenizer.bos_token_id)
        labels = [-100] * context_length + input_ids[context_length:]
        pad_id = tokenizer.pad_token_id
        pad_len = self.max_seq_length - len(input_ids)
        input_ids = input_ids + [pad_id] * pad_len
        labels = labels + [pad_id] * pad_len
        if self.ignore_pad_token_for_loss:
            labels = [-100 if tok == pad_id else tok for tok in labels]
        return input_ids, labels

    def collate_fn(self, batch):
        """Collate a list of raw examples into a dict of stacked tensors.

        Args:
            batch: Iterable of examples, each a JSON string or a dict.

        Returns:
            dict with ``"input_ids"`` and ``"labels"`` tensors of shape
            (num_kept_examples, max_seq_length).
        """
        all_input_ids = []
        all_labels = []
        for raw in batch:
            example = json.loads(raw) if isinstance(raw, str) else raw
            # Skip examples with an empty query or answer.
            if not (example[self.query_column] and example[self.response_column]):
                continue
            input_ids, labels = self._encode_example(example)
            all_input_ids.append(input_ids)
            all_labels.append(labels)
        return {
            "input_ids": torch.tensor(all_input_ids),
            "labels": torch.tensor(all_labels),
        }
if __name__ == "__main__":
    # Smoke test: load the training file, collate it with the ChatGLM
    # tokenizer, and print the first example's input_ids.

    class Args:
        """Minimal stand-in for the training script's argument namespace."""
        max_source_length = 128
        max_target_length = 128
        instruct_column = "instruct"
        query_column = "query"
        response_column = "answer"
        ignore_pad_token_for_loss = True
        train_path = "data/msra/instruct_data/train.txt"

    args = Args()

    # Imported lazily so the rest of the module works without transformers.
    # (AutoModel was previously imported here too but never used.)
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    data = load_data(args.train_path)
    print(data[0])

    ner_collate = NerCollate(args, tokenizer)
    # NOTE(review): collating the whole dataset in one call is memory-hungry;
    # the DataLoader below streams batches instead — kept for reference.
    # from torch.utils.data import DataLoader
    # train_dataloader = DataLoader(data,
    #                   batch_size=1,
    #                   shuffle=False,
    #                   drop_last=True,
    #                   num_workers=0,
    #                   collate_fn=ner_collate.collate_fn)
    # for step, batch in enumerate(train_dataloader):
    #     input_ids = batch["input_ids"]
    #     labels = batch["labels"]
    #     print(input_ids.shape, labels.shape)
    #     break
    train_dataset = ner_collate.collate_fn(data)
    print(train_dataset["input_ids"][0])