This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Numpy] Refactor Roberta #1269

Merged
merged 11 commits on Jul 21, 2020
42 changes: 35 additions & 7 deletions scripts/question_answering/README.md
@@ -50,7 +50,7 @@ python run_squad.py \
--epochs 3 \
--lr 2e-5 \
--warmup_ratio 0.1 \
--wd=0.01 \
--wd 0.01 \
--max_seq_length 512 \
--max_grad_norm 0.1 \
--overwrite_cache \
@@ -60,9 +60,9 @@ or evaluate SQuAD1.1 based on a SQuAD2.0 fine-tuned checkpoint as
```bash
python run_squad.py \
--model_name ${MODEL_NAME} \
--data_dir=squad \
--output_dir=${OUT_DIR} \
--param_checkpoint=${CKPT_PATH} \
--data_dir squad \
--output_dir ${OUT_DIR} \
--param_checkpoint ${CKPT_PATH} \
--version 2.0 \
--do_eval \
--gpus 0,1,2,3 \
@@ -90,18 +90,43 @@ python run_squad.py \
--lr 3e-4 \
--layerwise_decay 0.8 \
--warmup_ratio 0.1 \
--wd=0 \
--wd 0 \
--max_seq_length 512 \
--max_grad_norm 0.1 \
```

For RoBERTa and XLMR, we remove `segment_ids` and replace `[CLS]` and `[SEP]` with
`<s>` and `</s>`, which mark the beginning and end of a sentence in the original pretraining setup; a small sketch of the resulting input layout follows the command below.

```bash
VERSION=2.0 # Either 2.0 or 1.1
MODEL_NAME=fairseq_roberta_large

python run_squad.py \
--model_name ${MODEL_NAME} \
--data_dir squad \
--output_dir finetune_${MODEL_NAME}_squad_${VERSION} \
--version ${VERSION} \
--do_eval \
--do_train \
--batch_size 2 \
--num_accumulated 6 \
--gpus 0,1,2,3 \
--epochs 3 \
--lr 3e-5 \
--warmup_ratio 0.2 \
--wd 0.01 \
--max_seq_length 512 \
--max_grad_norm 0.1 \
```
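
To make the input layout described above concrete, here is a self-contained sketch of how a RoBERTa-style QA input is assembled without `segment_ids` (the token ids are placeholders; in practice they come from the model's tokenizer):

```python
# Illustrative only: RoBERTa/XLMR QA input layout without segment ids.
# <s> and </s> take the roles of [CLS] and [SEP]; the ids below are placeholders.
bos_id, eos_id = 0, 2                  # <s>, </s> (hypothetical ids)
question_ids = [100, 101, 102]         # hypothetical token ids for the question
context_ids = [200, 201, 202, 203]     # hypothetical token ids for the context

tokens = [bos_id] + question_ids + [eos_id] + context_ids + [eos_id]
token_types = None                     # no segment ids for RoBERTa/XLMR
valid_length = len(tokens)
print(tokens, valid_length)
```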

### Results
We reproduce the ALBERT models released by Google and fine-tune them on SQuAD with single models. ALBERT version 2 models are pre-trained without the dropout mechanism but with extra training steps compared to version 1 (see the [original paper](https://arxiv.org/abs/1909.11942) for details).

The listed models are fine-tuned with learning rate 2e-5, 3 epochs, warmup ratio 0.1, and max gradient norm 0.1 (as shown in the commands above). Note that `batch_size` is set per GPU; the global batch size is 48 for all experiments, and gradient accumulation (`num_accumulated`) can be used when memory is insufficient.
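For example, the RoBERTa command above uses `--batch_size 2` on 4 GPUs with `--num_accumulated 6`, which gives a global batch size of 2 × 4 × 6 = 48.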

Performance is shown in the table below, in which SQuAD1.1 is evaluated with SQuAD2.0 checkpoints.
Note that the standard SQuAD metrics are EM and F1. The former is an exact-match score between predictions and references, while the latter is a token-level F1 score in which the common tokens are treated as true positives.
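
As a minimal sketch (assuming whitespace tokenization and skipping the answer normalization done by the real evaluation code in `eval_utils.py`), the token-level F1 works as follows:

```python
# Minimal sketch of token-level F1: common tokens count as true positives.
# The full implementation is compute_f1 in scripts/question_answering/eval_utils.py.
import collections

def token_f1(prediction: str, reference: str) -> float:
    pred_toks = prediction.split()
    gold_toks = reference.split()
    common = collections.Counter(pred_toks) & collections.Counter(gold_toks)
    num_same = sum(common.values())            # true positives
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

print(token_f1("in the park", "walking in the park"))  # ~0.857
```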

|Reproduced ALBERT Models (F1/EM) | SQuAD 1.1 dev | SQuAD 2.0 dev |
|----------------------------------|---------------|---------------|
@@ -119,7 +144,7 @@ For reference, we've included the results from Google's Original Experiments
|ALBERT xlarge (googleresearch/albert) | 92.9/86.4 | 87.9/84.1 |
|ALBERT xxlarge (googleresearch/albert) | 94.6/89.1 | 89.8/86.9 |

For BERT and ELECTRA model, the results on SQuAD1.1 and SQuAD2.0 are given as follows.
For the rest of the pretrained models, the results on SQuAD1.1 and SQuAD2.0 are given as follows.

| Model Name | SQuAD1.1 dev | SQuAD2.0 dev |
|--------------------------|---------------|--------------|
@@ -128,8 +153,9 @@ For BERT and ELECTRA model, the results on SQuAD1.1 and SQuAD2.0 are given as fo
|ELECTRA small | 85.42/78.95 | 74.44/71.86 |
|ELECTRA base | 92.63/87.34 | 86.34/83.62 |
|ELECTRA large | 94.95/89.94 | 90.59/88.13 |
|RoBERTa large | 94.58/88.86 | 89.01/85.93 |

For reference, we have also included the results of Google's original version
For reference, we have also included the results of the original versions from Google and Fairseq.

| Model Name | SQuAD1.1 dev | SQuAD2.0 dev |
|--------------------------|----------------|---------------|
@@ -138,5 +164,7 @@ For reference, we have also included the results of Google's original version
|Google ELECTRA small | - /75.8 | - /70.1 |
|Google ELECTRA base | - /86.8 | - /83.7 |
|Google ELECTRA large | - /89.7 | - /88.1 |
|Fairseq RoBERTa large | 94.6/88.9 | 89.4/86.5 |


All experiments were done on AWS p3.8xlarge instances (4 x NVIDIA Tesla V100 16 GB).
2 changes: 1 addition & 1 deletion scripts/question_answering/eval_utils.py
@@ -51,7 +51,7 @@ def compute_exact(a_gold, a_pred):
def compute_f1(a_gold, a_pred):
"""
Compute the token-level f1 scores in which the common tokens are considered
as True Postives. Precision and recall are percentages of the number of
as True Positives. Precision and recall are percentages of the number of
common tokens in the prediction and ground truth, respectively.
"""
gold_toks = get_tokens(a_gold)
46 changes: 40 additions & 6 deletions scripts/question_answering/models.py
@@ -13,10 +13,23 @@ class ModelForQABasic(HybridBlock):
Here, we directly use the backbone network to extract the contextual embeddings and use
another dense layer to map the contextual embeddings to the start scores and end scores.

use_segmentation marks whether the input sentence is segmented into question/context parts. For
RoBERTa and XLMR, this flag is set to False, and the QA model then no longer accepts `token_types` as a valid input.

- use_segmentation=True:
tokens : <CLS> Question <SEP> Context <SEP>
token_types : 0 0 0 1 1

- use_segmentation=False:
tokens : <CLS> Question <SEP> Context <SEP>
token_types : None
"""
def __init__(self, backbone, weight_initializer=None, bias_initializer=None):
def __init__(self, backbone, weight_initializer=None, bias_initializer=None,
use_segmentation=True):
super().__init__()

self.backbone = backbone
self.use_segmentation = use_segmentation
self.qa_outputs = nn.Dense(units=2, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@@ -50,8 +63,11 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask):
The log-softmax scores that the position is the end position.
"""
# Get contextual embedding with the shape (batch_size, sequence_length, C)
contextual_embedding = self.backbone(tokens, token_types, valid_length)
scores = self.qa_outputs(contextual_embedding)
if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
scores = self.qa_outputs(contextual_embeddings)
start_scores = scores[:, :, 0]
end_scores = scores[:, :, 1]
start_logits = masked_logsoftmax(F, start_scores, mask=p_mask, axis=-1)
@@ -72,11 +88,23 @@ class ModelForQAConditionalV1(HybridBlock):

In the inference phase, we are able to use beam search to do the inference.

use_segmentation marks whether the input sentence is segmented into question/context parts. For
RoBERTa and XLMR, this flag is set to False, and the QA model then no longer accepts `token_types` as a valid input.

- use_segmentation=True:
tokens : <CLS> Question <SEP> Context <SEP>
token_types : 0 0 0 1 1

- use_segmentation=False:
tokens : <CLS> Question <SEP> Context <SEP>
token_types : None
"""
def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
activation='tanh', weight_initializer=None, bias_initializer=None):
activation='tanh', weight_initializer=None, bias_initializer=None,
use_segmentation=True):
super().__init__()
self.backbone = backbone
self.use_segmentation = use_segmentation
self.start_scores = nn.Dense(1, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@@ -209,7 +237,10 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask, start_pos
Shape (batch_size, sequence_length)
answerable_logits
"""
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
start_logits = self.get_start_logits(F, contextual_embeddings, p_mask)
end_logits = self.get_end_logits(F, contextual_embeddings,
F.np.expand_dims(start_position, axis=1),
@@ -257,7 +288,10 @@ def inference(self, tokens, token_types, valid_length, p_mask,
Shape (batch_size, sequence_length, 2)
"""
# Shape (batch_size, sequence_length, C)
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
start_logits = self.get_start_logits(mx.nd, contextual_embeddings, p_mask)
# The shape of start_top_index will be (..., start_top_n)
start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
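
For context (not part of the diff), a minimal, untested sketch of how the new `use_segmentation` flag is meant to be wired up for a backbone that takes no token types. The model name and the `get_backbone` call mirror `run_squad.py` below; the import of the QA head assumes you run from `scripts/question_answering/`, and the rest is illustrative:

```python
# Hypothetical usage sketch: build the QA head on top of a RoBERTa backbone,
# which is called without token_types when use_segmentation=False.
from gluonnlp.models import get_backbone
from models import ModelForQAConditionalV1   # scripts/question_answering/models.py

Model, cfg, tokenizer, local_params_path, _ = get_backbone('fairseq_roberta_large')
backbone = Model.from_cfg(cfg, use_pooler=False)
backbone.load_parameters(local_params_path, ignore_extra=True, cast_dtype=True)

qa_net = ModelForQAConditionalV1(backbone=backbone,
                                 dropout_prob=0.1,
                                 use_segmentation=False)  # RoBERTa/XLMR: no segment ids
qa_net.initialize()   # only the newly added QA layers still need initialization
qa_net.hybridize()
```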
33 changes: 21 additions & 12 deletions scripts/question_answering/run_squad.py
@@ -171,6 +171,12 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
self._max_seq_length = max_seq_length
self._max_query_length = max_query_length

vocab = tokenizer.vocab
self.pad_id = vocab.pad_id
# For the RoBERTa model, treat the special token <s> as [CLS] and </s> as [SEP]
self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id

def process_sample(self, feature: SquadFeature):
"""Process the data to the following format.

@@ -220,10 +226,9 @@ def process_sample(self, feature: SquadFeature):
doc_stride=self._doc_stride,
max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
for chunk in chunks:
data = np.array([self._tokenizer.vocab.cls_id] + truncated_query_ids +
[self._tokenizer.vocab.sep_id] +
data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] +
feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
[self._tokenizer.vocab.sep_id], dtype=np.int32)
[self.sep_id], dtype=np.int32)
valid_length = len(data)
segment_ids = np.array([0] + [0] * len(truncated_query_ids) +
[0] + [1] * chunk.length + [1], dtype=np.int32)
@@ -310,8 +315,10 @@ def get_network(model_name,
cfg
tokenizer
qa_net
use_segmentation
"""
# Create the network
use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
Model, cfg, tokenizer, download_params_path, _ = \
get_backbone(model_name, load_backbone=not backbone_path)
backbone = Model.from_cfg(cfg, use_pooler=False, dtype=dtype)
@@ -328,6 +335,7 @@ def get_network(model_name,
backbone_params_path, num_params, num_fixed_params))
qa_net = ModelForQAConditionalV1(backbone=backbone,
dropout_prob=dropout,
use_segmentation=use_segmentation,
weight_initializer=TruncNorm(stdev=0.02))
if checkpoint_path is None:
# Ignore the UserWarning during initialization,
@@ -337,7 +345,7 @@ def get_network(model_name,
qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
qa_net.hybridize()

return cfg, tokenizer, qa_net
return cfg, tokenizer, qa_net, use_segmentation


def untune_params(model, untunable_depth, not_included=[]):
@@ -407,10 +415,11 @@ def apply_layerwise_decay(model, layerwise_decay, not_included=[]):

def train(args):
ctx_l = parse_ctx(args.gpus)
cfg, tokenizer, qa_net = get_network(args.model_name, ctx_l,
args.classifier_dropout,
args.param_checkpoint,
args.backbone_path)
cfg, tokenizer, qa_net, use_segmentation \
= get_network(args.model_name, ctx_l,
args.classifier_dropout,
args.param_checkpoint,
args.backbone_path)
# Load the data
train_examples = get_squad_examples(args.data_dir, segment='train', version=args.version)
logging.info('Load data from {}, Version={}'.format(args.data_dir, args.version))
@@ -551,7 +560,7 @@ def train(args):
log_sample_num += len(tokens)
epoch_sample_num += len(tokens)
num_samples_per_update += len(tokens)
segment_ids = sample.segment_ids.as_in_ctx(ctx)
segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx)
@@ -779,7 +788,7 @@ def predict_extended(original_feature,

def evaluate(args, last=True):
ctx_l = parse_ctx(args.gpus)
cfg, tokenizer, qa_net = get_network(
cfg, tokenizer, qa_net, use_segmentation = get_network(
args.model_name, ctx_l, args.classifier_dropout, dtype=args.eval_dtype)
if args.eval_dtype == 'float16':
qa_net.cast('float16')
@@ -847,13 +856,13 @@ def eval_validation(ckpt_name, best_eval):
tokens = sample.data.as_in_ctx(ctx)
total_num += len(tokens)
log_num += len(tokens)
segment_ids = sample.segment_ids.as_in_ctx(ctx)
segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
= qa_net.inference(tokens, segment_ids, valid_length, p_mask,
args.start_top_n, args.end_top_n)
args.start_top_n, args.end_top_n)
for i, qas_id in enumerate(sample.qas_id):
result = RawResultExtended(qas_id=qas_id,
start_top_logits=start_top_logits[i].asnumpy(),
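
To see what the special-token fallback above resolves to, here is a small sketch mirroring the logic added in `run_squad.py` (fetching the backbone assets is only needed to obtain the tokenizer):

```python
# Sketch: RoBERTa vocabularies expose bos/eos but no cls/sep special tokens,
# so the preprocessor falls back to <s>/</s> in place of [CLS]/[SEP].
from gluonnlp.models import get_backbone

_, _, tokenizer, _, _ = get_backbone('fairseq_roberta_large')
vocab = tokenizer.vocab
cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id
print(cls_id, sep_id)   # ids of <s> and </s> for the RoBERTa vocabulary
```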
18 changes: 9 additions & 9 deletions src/gluonnlp/data/tokenizers.py
@@ -1273,7 +1273,7 @@ class SentencepieceTokenizer(BaseTokenizerWithVocab):
algorithm.
alpha
A scalar for a smoothing parameter for probability rescaling.
do_lower
lowercase
Whether to convert the input string to lower case
**kwargs

@@ -1295,7 +1295,7 @@ class SentencepieceTokenizer(BaseTokenizerWithVocab):

def __init__(self, model_path: Optional[str] = None,
vocab: Optional[Union[str, Vocab]] = None,
nbest: int = 0, alpha: float = 0.0, do_lower=False,
nbest: int = 0, alpha: float = 0.0, lowercase=False,
**kwargs):
self._model_path = model_path
sentencepiece = try_import_sentencepiece()
@@ -1305,7 +1305,7 @@ def __init__(self, model_path: Optional[str] = None,
self._sp_model.load(model_path)
self._nbest = nbest
self._alpha = alpha
self._do_lower = do_lower
self._lowercase = lowercase
self._meta_symbol = u'▁'
sp_model_all_tokens = [self._sp_model.id_to_piece(i) for i in range(len(self._sp_model))]
special_tokens_kv = dict()
@@ -1387,7 +1387,7 @@ def encode(self, sentences, output_type=str):
is_multi_sentences = isinstance(sentences, list)
if not is_multi_sentences:
sentences = [sentences]
if self._do_lower:
if self._lowercase:
sentences = [sentence.lower() for sentence in sentences]
if output_type is str:
ret = [self._sp_model.sample_encode_as_pieces(sentence, self._nbest, self._alpha)
@@ -1426,7 +1426,7 @@ def encode_with_offsets(self, sentences, output_type=str):
token_ids = []
offsets = []
for sentence in sentences:
if self._do_lower:
if self._lowercase:
sentence = sentence.lower()
spt = self._spt_cls()
spt.ParseFromString(self._sp_model.SampleEncodeAsSerializedProto(
@@ -1487,8 +1487,8 @@ def set_vocab(self, vocab):
'SentencepieceTokenizer.')

@property
def do_lower(self):
return self._do_lower
def lowercase(self):
return self._lowercase

def set_subword_regularization(self, nbest, alpha):
self._nbest = nbest
@@ -1497,11 +1497,11 @@ def set_subword_regularization(self, nbest, alpha):
def __repr__(self):
ret = '{}(\n' \
' model_path = {}\n' \
' do_lower = {}, nbest = {}, alpha = {}\n' \
' lowercase = {}, nbest = {}, alpha = {}\n' \
' vocab = {}\n' \
')'.format(self.__class__.__name__,
os.path.realpath(self._model_path),
self._do_lower, self._nbest, self._alpha,
self._lowercase, self._nbest, self._alpha,
self._vocab)
return ret

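
A short sketch of the renamed flag (the sentencepiece model path here is a placeholder; any trained `.model` file works):

```python
# Sketch: construct a SentencepieceTokenizer with the renamed `lowercase` flag.
from gluonnlp.data.tokenizers import SentencepieceTokenizer

tokenizer = SentencepieceTokenizer(model_path='spm.model',  # placeholder path
                                   lowercase=True)
print(tokenizer.lowercase)            # True
print(tokenizer.encode('Hello World'))
```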