Add new GLUE dataset and example; update dataset. (PaddlePaddle#59)
* Add new GLUE dataset and example; update dataset.

* update load_dataset() argument name
smallv0221 authored Mar 3, 2021
1 parent d15ff60 commit e76fa2f
Showing 11 changed files with 438 additions and 158 deletions.
132 changes: 45 additions & 87 deletions examples/glue/run_glue.py
@@ -26,8 +26,8 @@
 from paddle.io import DataLoader
 from paddle.metric import Metric, Accuracy, Precision, Recall
 
-from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueQQP, GlueMNLI, GlueQNLI, GlueRTE
-from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data import Stack, Tuple, Pad, Dict
 from paddlenlp.data.sampler import SamplerHelper
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer
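
Note: the per-task dataset classes are replaced here by the single load_dataset() entry point. As a rough usage sketch (the task name below is illustrative; the splits argument matches the calls later in this diff):

    # Sketch of the new unified loader; any key of TASK_CLASSES works as the task name.
    from paddlenlp.datasets import load_dataset

    train_ds = load_dataset('glue', 'sst-2', splits='train')
    dev_matched, dev_mismatched = load_dataset(
        'glue', 'mnli', splits=['dev_matched', 'dev_mismatched'])
    print(train_ds.label_list)  # label names for classification tasks, None for sts-b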
@@ -40,14 +40,14 @@
 logger = logging.getLogger(__name__)
 
 TASK_CLASSES = {
-    "cola": (GlueCoLA, Mcc),
-    "sst-2": (GlueSST2, Accuracy),
-    "mrpc": (GlueMRPC, AccuracyAndF1),
-    "sts-b": (GlueSTSB, PearsonAndSpearman),
-    "qqp": (GlueQQP, AccuracyAndF1),
-    "mnli": (GlueMNLI, Accuracy),
-    "qnli": (GlueQNLI, Accuracy),
-    "rte": (GlueRTE, Accuracy),
+    "cola": Mcc,
+    "sst-2": Accuracy,
+    "mrpc": AccuracyAndF1,
+    "sts-b": PearsonAndSpearman,
+    "qqp": AccuracyAndF1,
+    "mnli": Accuracy,
+    "qnli": Accuracy,
+    "rte": Accuracy,
 }
 
 MODEL_CLASSES = {
@@ -211,66 +211,25 @@ def convert_example(example,
                     max_seq_length=512,
                     is_test=False):
     """convert a glue example into necessary features"""
 
-    def _truncate_seqs(seqs, max_seq_length):
-        if len(seqs) == 1:  # single sentence
-            # Account for [CLS] and [SEP] with "- 2"
-            seqs[0] = seqs[0][0:(max_seq_length - 2)]
-        else:  # Sentence pair
-            # Account for [CLS], [SEP], [SEP] with "- 3"
-            tokens_a, tokens_b = seqs
-            max_seq_length -= 3
-            while True:  # Truncate with longest_first strategy
-                total_length = len(tokens_a) + len(tokens_b)
-                if total_length <= max_seq_length:
-                    break
-                if len(tokens_a) > len(tokens_b):
-                    tokens_a.pop()
-                else:
-                    tokens_b.pop()
-        return seqs
-
-    def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1):
-        concat = sum((seq + sep for sep, seq in zip(separators, seqs)), [])
-        segment_ids = sum(
-            ([i] * (len(seq) + len(sep))
-             for i, (sep, seq) in enumerate(zip(separators, seqs))), [])
-        if isinstance(seq_mask, int):
-            seq_mask = [[seq_mask] * len(seq) for seq in seqs]
-        if isinstance(separator_mask, int):
-            separator_mask = [[separator_mask] * len(sep) for sep in separators]
-        p_mask = sum((s_mask + mask
-                      for sep, seq, s_mask, mask in zip(
-                          separators, seqs, seq_mask, separator_mask)), [])
-        return concat, segment_ids, p_mask
-
     if not is_test:
         # `label_list == None` is for regression task
         label_dtype = "int64" if label_list else "float32"
         # Get the label
-        label = example[-1]
-        example = example[:-1]
-        # Create label maps if classification task
-        if label_list:
-            label_map = {}
-            for (i, l) in enumerate(label_list):
-                label_map[l] = i
-            label = label_map[label]
+        label = example['labels']
         label = np.array([label], dtype=label_dtype)
 
-    # Tokenize raw text
-    if len(example) == 1:
-        example = tokenizer(example[0], max_seq_len=max_seq_length)
+    # Convert raw text to feature
+    if len(example) == 2:
+        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
     else:
         example = tokenizer(
-            example[0], text_pair=example[1], max_seq_len=max_seq_length)
+            example['sentence1'],
+            text_pair=example['sentence2'],
+            max_seq_len=max_seq_length)
 
     if not is_test:
-        return example['input_ids'], example['token_type_ids'], len(example[
-            'input_ids']), label
+        return example['input_ids'], example['token_type_ids'], label
     else:
-        return example['input_ids'], example['token_type_ids'], len(example[
-            'input_ids'])
+        return example['input_ids'], example['token_type_ids']
 
 
 def do_train(args):
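
Note: convert_example now consumes dict-style examples keyed by column name instead of positional tuples, and the old label_map step is gone, so labels are assumed to already be indices (example['labels'] is used as-is). A minimal sketch of the new contract, with invented values:

    # Sketch only; sentences and label are made up, and `tokenizer` is any
    # PaddleNLP tokenizer (e.g. BertTokenizer.from_pretrained('bert-base-uncased')).
    example = {'sentence1': 'A man is eating.',
               'sentence2': 'Someone is eating food.',
               'labels': 1}
    input_ids, token_type_ids, label = convert_example(
        example, tokenizer, label_list=['0', '1'], max_seq_length=128)
    # input_ids      -> ids for [CLS] sentence1 [SEP] sentence2 [SEP]
    # token_type_ids -> 0 over the first segment, 1 over the second
    # label          -> np.array([1], dtype='int64'); 'float32' when label_list is None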
@@ -281,69 +240,67 @@ def do_train(args):
     set_seed(args)
 
     args.task_name = args.task_name.lower()
-    dataset_class, metric_class = TASK_CLASSES[args.task_name]
+    metric_class = TASK_CLASSES[args.task_name]
     args.model_type = args.model_type.lower()
     model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
 
-    train_dataset = dataset_class.get_datasets(["train"])
+    train_ds = load_dataset('glue', args.task_name, splits="train")
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
 
     trans_func = partial(
         convert_example,
         tokenizer=tokenizer,
-        label_list=train_dataset.get_labels(),
+        label_list=train_ds.label_list,
         max_seq_length=args.max_seq_length)
-    train_dataset = train_dataset.apply(trans_func, lazy=True)
+    train_ds = train_ds.map(trans_func, lazy=True)
     train_batch_sampler = paddle.io.DistributedBatchSampler(
-        train_dataset, batch_size=args.batch_size, shuffle=True)
+        train_ds, batch_size=args.batch_size, shuffle=True)
     batchify_fn = lambda samples, fn=Tuple(
         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
         Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
-        Stack(),  # length
-        Stack(dtype="int64" if train_dataset.get_labels() else "float32")  # label
-    ): [data for i, data in enumerate(fn(samples)) if i != 2]
+        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
+    ): fn(samples)
     train_data_loader = DataLoader(
-        dataset=train_dataset,
+        dataset=train_ds,
         batch_sampler=train_batch_sampler,
         collate_fn=batchify_fn,
        num_workers=0,
         return_list=True)
     if args.task_name == "mnli":
-        dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets(
-            ["dev_matched", "dev_mismatched"])
-        dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True)
-        dev_dataset_mismatched = dev_dataset_mismatched.apply(
-            trans_func, lazy=True)
+        dev_ds_matched, dev_ds_mismatched = load_dataset(
+            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
+
+        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
+        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
         dev_batch_sampler_matched = paddle.io.BatchSampler(
-            dev_dataset_matched, batch_size=args.batch_size, shuffle=False)
+            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
         dev_data_loader_matched = DataLoader(
-            dataset=dev_dataset_matched,
+            dataset=dev_ds_matched,
             batch_sampler=dev_batch_sampler_matched,
             collate_fn=batchify_fn,
             num_workers=0,
             return_list=True)
         dev_batch_sampler_mismatched = paddle.io.BatchSampler(
-            dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False)
+            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
         dev_data_loader_mismatched = DataLoader(
-            dataset=dev_dataset_mismatched,
+            dataset=dev_ds_mismatched,
             batch_sampler=dev_batch_sampler_mismatched,
             collate_fn=batchify_fn,
             num_workers=0,
             return_list=True)
     else:
-        dev_dataset = dataset_class.get_datasets(["dev"])
-        dev_dataset = dev_dataset.apply(trans_func, lazy=True)
+        dev_ds = load_dataset('glue', args.task_name, splits='dev')
+        dev_ds = dev_ds.map(trans_func, lazy=True)
         dev_batch_sampler = paddle.io.BatchSampler(
-            dev_dataset, batch_size=args.batch_size, shuffle=False)
+            dev_ds, batch_size=args.batch_size, shuffle=False)
         dev_data_loader = DataLoader(
-            dataset=dev_dataset,
+            dataset=dev_ds,
             batch_sampler=dev_batch_sampler,
             collate_fn=batchify_fn,
             num_workers=0,
             return_list=True)
 
-    num_classes = 1 if train_dataset.get_labels() == None else len(
-        train_dataset.get_labels())
+    num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)
     model = model_class.from_pretrained(
         args.model_name_or_path, num_classes=num_classes)
     if paddle.distributed.get_world_size() > 1:
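
Note: convert_example no longer returns a sequence-length field, so batchify_fn no longer has to filter it back out (the old `if i != 2` list comprehension). A self-contained sketch of how the Tuple/Pad/Stack combinators collate a batch, with toy ids and 0 standing in for the tokenizer's pad values:

    # Toy collation sketch; the real code takes pad values from the tokenizer.
    from paddlenlp.data import Pad, Stack, Tuple

    samples = [
        ([1, 5, 7, 2], [0, 0, 0, 0], [1]),  # (input_ids, token_type_ids, label)
        ([1, 9, 2],    [0, 0, 0],    [0]),
    ]
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=0),  # pad input_ids to the longest sample in the batch
        Pad(axis=0, pad_val=0),  # pad token_type_ids the same way
        Stack(dtype="int64"),    # stack labels into a (batch_size, 1) array
    )
    input_ids, token_type_ids, labels = batchify_fn(samples)
    # input_ids.shape == (2, 4); the shorter row is padded with 0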
@@ -368,8 +325,8 @@ def do_train(args):
             if not any(nd in n for nd in ["bias", "norm"])
         ])
 
-    loss_fct = paddle.nn.loss.CrossEntropyLoss() if train_dataset.get_labels(
-    ) else paddle.nn.loss.MSELoss()
+    loss_fct = paddle.nn.loss.CrossEntropyLoss(
+    ) if train_ds.label_list else paddle.nn.loss.MSELoss()
 
     metric = metric_class()
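
Note: the loss choice keys off label_list: classification tasks carry a label list and use cross entropy, while sts-b (the one regression task in GLUE) has a label_list of None and falls back to mean squared error. Spelled out:

    # Equivalent, more explicit form of the branch above.
    if train_ds.label_list:                    # classification
        loss_fct = paddle.nn.loss.CrossEntropyLoss()
    else:                                      # regression: label_list is None (sts-b)
        loss_fct = paddle.nn.loss.MSELoss()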

@@ -378,6 +335,7 @@
     for epoch in range(args.num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
+
             input_ids, segment_ids, labels = batch
             logits = model(input_ids, segment_ids)
             loss = loss_fct(logits, labels)
@@ -392,7 +350,7 @@
                         paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                         args.logging_steps / (time.time() - tic_train)))
                 tic_train = time.time()
-            if global_step % args.save_steps == 0:
+            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                 tic_eval = time.time()
                 if args.task_name == "mnli":
                     evaluate(model, loss_fct, metric, dev_data_loader_matched)
@@ -123,7 +123,6 @@ def prepare_train_features(examples):
         questions,
         contexts,
         stride=args.doc_stride,
-        pad_to_max_seq_len=True,
         max_seq_len=args.max_seq_length)
 
     # Let's label those examples!
@@ -154,9 +153,11 @@ def prepare_train_features(examples):
                 token_start_index += 1
 
             # End token index of the current span in the text.
-            token_end_index = len(input_ids) - 2
+            token_end_index = len(input_ids) - 1
             while sequence_ids[token_end_index] != 1:
                 token_end_index -= 1
+            # Minus one more to reach actual text
+            token_end_index -= 1
 
             # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
             if not (offsets[token_start_index][0] <= start_char and
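
Note: this hunk (repeated verbatim in run_squad.py below) fixes an off-by-one in the answer-span labeling: starting the backward scan at len(input_ids) - 2 could leave token_end_index on the final [SEP] when padding follows it, so the new code scans from the very end and then steps back once. A toy reproduction, assuming the convention implied by the diff that the trailing [SEP] shares the context's sequence id of 1:

    # tokens:        [CLS]  q1  [SEP]  c1  c2  [SEP]  [PAD]  [PAD]
    sequence_ids = [None,   0,  None,   1,   1,    1,  None,  None]

    def end_index_old(seq_ids):
        i = len(seq_ids) - 2              # old starting point
        while seq_ids[i] != 1:
            i -= 1
        return i                          # can stop on [SEP] itself

    def end_index_new(seq_ids):
        i = len(seq_ids) - 1              # new starting point
        while seq_ids[i] != 1:
            i -= 1
        return i - 1                      # minus one more to reach actual text

    print(end_index_old(sequence_ids))    # 5 -> index of [SEP] (off by one)
    print(end_index_new(sequence_ids))    # 4 -> index of c2, the last context token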
4 changes: 3 additions & 1 deletion examples/machine_reading_comprehension/SQuAD/run_squad.py
@@ -158,9 +158,11 @@ def prepare_train_features(examples):
                 token_start_index += 1
 
             # End token index of the current span in the text.
-            token_end_index = len(input_ids) - 2
+            token_end_index = len(input_ids) - 1
             while sequence_ids[token_end_index] != 1:
                 token_end_index -= 1
+            # Minus one more to reach actual text
+            token_end_index -= 1
 
             # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
             if not (offsets[token_start_index][0] <= start_char and
3 changes: 1 addition & 2 deletions paddlenlp/datasets/experimental/cmrc2018.py
@@ -31,11 +31,10 @@ def _get_data(self, mode, **kwargs):
         if not os.path.exists(fullname) or (data_hash and
                                             not md5file(fullname) == data_hash):
             get_path_from_url(URL, default_root)
-            fullname = os.path.join(default_root, filename)
 
         return fullname
 
-    def _read(self, filename):
+    def _read(self, filename, *args):
         with open(filename, "r", encoding="utf8") as f:
             input_data = json.load(f)["data"]
             for entry in input_data:
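
Note: two small fixes here: re-joining fullname after the download was redundant (fullname is already default_root/filename before the existence check), and _read() now tolerates extra positional arguments, presumably so every dataset builder's reader shares one call signature. A self-contained sketch of the checksum-guarded download pattern, with md5file reimplemented via hashlib and download() standing in for get_path_from_url:

    import hashlib
    import os

    def md5file(path):
        with open(path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()

    def ensure_data(default_root, filename, data_hash, download):
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (data_hash and
                                            md5file(fullname) != data_hash):
            download(default_root)  # fetches the file into default_root
        return fullname             # already the final path; no re-join needed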