diff --git a/scripts/question_answering/README.md b/scripts/question_answering/README.md
index c7fd556c37..3a24cda2d6 100644
--- a/scripts/question_answering/README.md
+++ b/scripts/question_answering/README.md
@@ -50,7 +50,7 @@ python run_squad.py \
--epochs 3 \
--lr 2e-5 \
--warmup_ratio 0.1 \
- --wd=0.01 \
+ --wd 0.01 \
--max_seq_length 512 \
--max_grad_norm 0.1 \
--overwrite_cache \
@@ -60,9 +60,9 @@ or evaluate SQuAD1.1 based on a SQuAD2.0 fine-tuned checkpoint as
```bash
python run_squad.py \
--model_name ${MODEL_NAME} \
- --data_dir=squad \
- --output_dir=${OUT_DIR} \
- --param_checkpoint=${CKPT_PATH} \
+ --data_dir squad \
+ --output_dir ${OUT_DIR} \
+ --param_checkpoint ${CKPT_PATH} \
--version 2.0 \
--do_eval \
--gpus 0,1,2,3 \
@@ -90,11 +90,35 @@ python run_squad.py \
--lr 3e-4 \
--layerwise_decay 0.8 \
--warmup_ratio 0.1 \
- --wd=0 \
+ --wd 0 \
--max_seq_length 512 \
--max_grad_norm 0.1 \
```
+For RoBERTa and XLMR, we remove `segment_ids` and replace `[CLS]` and `[SEP]` with
+`<s>` and `</s>`, which originally mark the beginning and end of a sentence respectively.
+
+```bash
+VERSION=2.0 # Either 2.0 or 1.1
+MODEL_NAME=fairseq_roberta_large
+
+python run_squad.py \
+ --model_name ${MODEL_NAME} \
+ --data_dir squad \
+    --output_dir finetune_${MODEL_NAME}_squad_${VERSION} \
+ --version ${VERSION} \
+ --do_eval \
+ --do_train \
+ --batch_size 2 \
+ --num_accumulated 6 \
+ --gpus 0,1,2,3 \
+ --epochs 3 \
+ --lr 3e-5 \
+ --warmup_ratio 0.2 \
+ --wd 0.01 \
+ --max_seq_length 512 \
+ --max_grad_norm 0.1 \
+```
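+
+With the command above, the effective global batch size is `batch_size` x #GPUs x `num_accumulated` = 2 x 4 x 6 = 48, consistent with the other experiments. To make the input format concrete, below is a minimal, illustrative sketch of how a training chunk is assembled for RoBERTa/XLMR (the token ids are made up; only the structure matters):
+
+```python
+# RoBERTa/XLMR: <s> and </s> play the roles of [CLS] and [SEP], and no
+# segment ids are fed to the backbone (use_segmentation=False).
+cls_id, sep_id = 0, 2                # <s> and </s> ids (illustrative values)
+query_ids = [100, 101, 102]          # hypothetical question token ids
+context_ids = [200, 201, 202, 203]   # hypothetical context-chunk token ids
+
+data = [cls_id] + query_ids + [sep_id] + context_ids + [sep_id]
+# A BERT-style model would additionally build
+#   segment_ids = [0] * (len(query_ids) + 2) + [1] * (len(context_ids) + 1)
+# whereas RoBERTa/XLMR backbones are called with (tokens, valid_length) only.
+```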
### Results
We reproduced the ALBERT model which is released by Google, and fine-tune the the SQuAD with single models. ALBERT Version 2 are pre-trained without the dropout mechanism but with extra training steps compared to the version 1 (see the [original paper](https://arxiv.org/abs/1909.11942) for details).
@@ -102,6 +126,7 @@ We reproduced the ALBERT model which is released by Google, and fine-tune the th
Fine-tuning the listed models with hyper-parameter learning rate 2e-5, epochs 3, warmup ratio 0.1 and max gradient norm 0.1 (as shown in command). Notice that the `batch_size` is set for each GPU and the global batch size is 48 for all experiments, besides that gradient accumulation (`num_accumulated`) is supported in the case of out of memory.
Performance are shown in the table below, in which the SQuAD1.1 are evaluated with SQuAD2.0 checkpoints.
+Note that the standard metrics for SQuAD are EM and F1: the former is the exact-match score between predictions and references, while the latter is a token-level F1 score in which the common tokens are treated as true positives.
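+
+As a rough sketch of the token-level F1 (a simplified version of `compute_f1` in `eval_utils.py`, which additionally normalizes the answers via `get_tokens` before comparing):
+
+```python
+from collections import Counter
+
+def token_f1(pred_tokens, gold_tokens):
+    """Token-level F1: tokens shared by prediction and reference count as true positives."""
+    common = Counter(pred_tokens) & Counter(gold_tokens)
+    num_common = sum(common.values())
+    if num_common == 0:
+        return 0.0
+    precision = num_common / len(pred_tokens)
+    recall = num_common / len(gold_tokens)
+    return 2 * precision * recall / (precision + recall)
+
+pred, gold = "in the park".split(), "the park".split()
+print(token_f1(pred, gold))   # 0.8
+print(float(pred == gold))    # EM: 0.0, not an exact match
+```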
|Reproduced ALBERT Models (F1/EM) | SQuAD 1.1 dev | SQuAD 2.0 dev |
|----------------------------------|---------------|---------------|
@@ -119,7 +144,7 @@ For reference, we've included the results from Google's Original Experiments
|ALBERT xlarge (googleresearch/albert) | 92.9/86.4 | 87.9/84.1 |
|ALBERT xxlarge (googleresearch/albert) | 94.6/89.1 | 89.8/86.9 |
-For BERT and ELECTRA model, the results on SQuAD1.1 and SQuAD2.0 are given as follows.
+For the rest of the pretrained models, the results on SQuAD1.1 and SQuAD2.0 are given as follows.
| Model Name | SQuAD1.1 dev | SQuAD2.0 dev |
|--------------------------|---------------|--------------|
@@ -128,8 +153,9 @@ For BERT and ELECTRA model, the results on SQuAD1.1 and SQuAD2.0 are given as fo
|ELECTRA small | 85.42/78.95 | 74.44/71.86 |
|ELECTRA base | 92.63/87.34 | 86.34/83.62 |
|ELECTRA large | 94.95/89.94 | 90.59/88.13 |
+|RoBERTa large | 94.58/88.86 | 89.01/85.93 |
-For reference, we have also included the results of Google's original version
+For reference, we have also included the original results reported by Google and Fairseq
| Model Name | SQuAD1.1 dev | SQuAD2.0 dev |
|--------------------------|----------------|---------------|
@@ -138,5 +164,7 @@ For reference, we have also included the results of Google's original version
|Google ELECTRA base | - /75.8 | - /70.1 |
|Google ELECTRA base | - /86.8 | - /83.7 |
|Google ELECTRA large | - /89.7 | - /88.1 |
+|Fairseq RoBERTa large | 94.6/88.9 | 89.4/86.5 |
+
All experiments done on AWS P3.8xlarge (4 x NVIDIA Tesla V100 16 GB)
diff --git a/scripts/question_answering/eval_utils.py b/scripts/question_answering/eval_utils.py
index ef7bb3133b..4f9db4916e 100644
--- a/scripts/question_answering/eval_utils.py
+++ b/scripts/question_answering/eval_utils.py
@@ -51,7 +51,7 @@ def compute_exact(a_gold, a_pred):
def compute_f1(a_gold, a_pred):
"""
Compute the token-level f1 scores in which the common tokens are considered
- as True Postives. Precision and recall are percentages of the number of
+ as True Positives. Precision and recall are percentages of the number of
common tokens in the prediction and groud truth, respectively.
"""
gold_toks = get_tokens(a_gold)
diff --git a/scripts/question_answering/models.py b/scripts/question_answering/models.py
index 641247e937..cb85cb7abb 100644
--- a/scripts/question_answering/models.py
+++ b/scripts/question_answering/models.py
@@ -13,10 +13,23 @@ class ModelForQABasic(HybridBlock):
Here, we directly use the backbone network to extract the contextual embeddings and use
another dense layer to map the contextual embeddings to the start scores and end scores.
+    use_segmentation is used to mark whether we segment the input sentence. In RoBERTa and XLMR,
+    this flag is set to False, and `token_types` is ignored by the QA model (None can be passed).
+
+    - use_segmentation=True:
+        tokens :      Question Context
+        token_types:  0 0 0    1 1
+
+    - use_segmentation=False:
+        tokens :      Question Context
+        token_types:  None
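+
+    A minimal usage sketch (illustrative only; `backbone` is a pretrained backbone instance and
+    `tokens`, `valid_length`, `p_mask` are batched arrays as documented in `hybrid_forward`):
+
+        qa_net = ModelForQABasic(backbone, use_segmentation=False)
+        qa_net.initialize()  # re-init warnings for the pretrained backbone can be ignored
+        start_logits, end_logits = qa_net(tokens, None, valid_length, p_mask)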
"""
- def __init__(self, backbone, weight_initializer=None, bias_initializer=None):
+ def __init__(self, backbone, weight_initializer=None, bias_initializer=None,
+ use_segmentation=True):
super().__init__()
+
self.backbone = backbone
+ self.use_segmentation = use_segmentation
self.qa_outputs = nn.Dense(units=2, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@@ -50,8 +63,11 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask):
The log-softmax scores that the position is the end position.
"""
# Get contextual embedding with the shape (batch_size, sequence_length, C)
- contextual_embedding = self.backbone(tokens, token_types, valid_length)
- scores = self.qa_outputs(contextual_embedding)
+ if self.use_segmentation:
+ contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+ else:
+ contextual_embeddings = self.backbone(tokens, valid_length)
+ scores = self.qa_outputs(contextual_embeddings)
start_scores = scores[:, :, 0]
end_scores = scores[:, :, 1]
start_logits = masked_logsoftmax(F, start_scores, mask=p_mask, axis=-1)
@@ -72,11 +88,23 @@ class ModelForQAConditionalV1(HybridBlock):
In the inference phase, we are able to use beam search to do the inference.
+    use_segmentation is used to mark whether we segment the input sentence. In RoBERTa and XLMR,
+    this flag is set to False, and `token_types` is ignored by the QA model (None can be passed).
+
+    - use_segmentation=True:
+        tokens :      Question Context
+        token_types:  0 0 0    1 1
+
+    - use_segmentation=False:
+        tokens :      Question Context
+        token_types:  None
"""
def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
- activation='tanh', weight_initializer=None, bias_initializer=None):
+ activation='tanh', weight_initializer=None, bias_initializer=None,
+ use_segmentation=True):
super().__init__()
self.backbone = backbone
+ self.use_segmentation = use_segmentation
self.start_scores = nn.Dense(1, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@@ -209,7 +237,10 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask, start_pos
Shape (batch_size, sequence_length)
answerable_logits
"""
- contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+ if self.use_segmentation:
+ contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+ else:
+ contextual_embeddings = self.backbone(tokens, valid_length)
start_logits = self.get_start_logits(F, contextual_embeddings, p_mask)
end_logits = self.get_end_logits(F, contextual_embeddings,
F.np.expand_dims(start_position, axis=1),
@@ -257,7 +288,10 @@ def inference(self, tokens, token_types, valid_length, p_mask,
Shape (batch_size, sequence_length, 2)
"""
# Shape (batch_size, sequence_length, C)
- contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+ if self.use_segmentation:
+ contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+ else:
+ contextual_embeddings = self.backbone(tokens, valid_length)
start_logits = self.get_start_logits(mx.nd, contextual_embeddings, p_mask)
# The shape of start_top_index will be (..., start_top_n)
start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index c363225e19..820aec0c46 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -171,6 +171,12 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
self._max_seq_length = max_seq_length
self._max_query_length = max_query_length
+ vocab = tokenizer.vocab
+ self.pad_id = vocab.pad_id
+        # For the RoBERTa model, the special tokens <s> and </s> serve as [CLS] and [SEP]
+ self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
+ self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id
+
def process_sample(self, feature: SquadFeature):
"""Process the data to the following format.
@@ -220,10 +226,9 @@ def process_sample(self, feature: SquadFeature):
doc_stride=self._doc_stride,
max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
for chunk in chunks:
- data = np.array([self._tokenizer.vocab.cls_id] + truncated_query_ids +
- [self._tokenizer.vocab.sep_id] +
+ data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] +
feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
- [self._tokenizer.vocab.sep_id], dtype=np.int32)
+ [self.sep_id], dtype=np.int32)
valid_length = len(data)
segment_ids = np.array([0] + [0] * len(truncated_query_ids) +
[0] + [1] * chunk.length + [1], dtype=np.int32)
@@ -310,8 +315,10 @@ def get_network(model_name,
cfg
tokenizer
qa_net
+ use_segmentation
"""
# Create the network
+ use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
Model, cfg, tokenizer, download_params_path, _ = \
get_backbone(model_name, load_backbone=not backbone_path)
backbone = Model.from_cfg(cfg, use_pooler=False, dtype=dtype)
@@ -328,6 +335,7 @@ def get_network(model_name,
backbone_params_path, num_params, num_fixed_params))
qa_net = ModelForQAConditionalV1(backbone=backbone,
dropout_prob=dropout,
+ use_segmentation=use_segmentation,
weight_initializer=TruncNorm(stdev=0.02))
if checkpoint_path is None:
# Ignore the UserWarning during initialization,
@@ -337,7 +345,7 @@ def get_network(model_name,
qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
qa_net.hybridize()
- return cfg, tokenizer, qa_net
+ return cfg, tokenizer, qa_net, use_segmentation
def untune_params(model, untunable_depth, not_included=[]):
@@ -407,10 +415,11 @@ def apply_layerwise_decay(model, layerwise_decay, not_included=[]):
def train(args):
ctx_l = parse_ctx(args.gpus)
- cfg, tokenizer, qa_net = get_network(args.model_name, ctx_l,
- args.classifier_dropout,
- args.param_checkpoint,
- args.backbone_path)
+ cfg, tokenizer, qa_net, use_segmentation \
+ = get_network(args.model_name, ctx_l,
+ args.classifier_dropout,
+ args.param_checkpoint,
+ args.backbone_path)
# Load the data
train_examples = get_squad_examples(args.data_dir, segment='train', version=args.version)
logging.info('Load data from {}, Version={}'.format(args.data_dir, args.version))
@@ -551,7 +560,7 @@ def train(args):
log_sample_num += len(tokens)
epoch_sample_num += len(tokens)
num_samples_per_update += len(tokens)
- segment_ids = sample.segment_ids.as_in_ctx(ctx)
+ segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx)
@@ -779,7 +788,7 @@ def predict_extended(original_feature,
def evaluate(args, last=True):
ctx_l = parse_ctx(args.gpus)
- cfg, tokenizer, qa_net = get_network(
+ cfg, tokenizer, qa_net, use_segmentation = get_network(
args.model_name, ctx_l, args.classifier_dropout, dtype=args.eval_dtype)
if args.eval_dtype == 'float16':
qa_net.cast('float16')
@@ -847,13 +856,13 @@ def eval_validation(ckpt_name, best_eval):
tokens = sample.data.as_in_ctx(ctx)
total_num += len(tokens)
log_num += len(tokens)
- segment_ids = sample.segment_ids.as_in_ctx(ctx)
+ segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
= qa_net.inference(tokens, segment_ids, valid_length, p_mask,
- args.start_top_n, args.end_top_n)
+ args.start_top_n, args.end_top_n)
for i, qas_id in enumerate(sample.qas_id):
result = RawResultExtended(qas_id=qas_id,
start_top_logits=start_top_logits[i].asnumpy(),
diff --git a/src/gluonnlp/data/tokenizers.py b/src/gluonnlp/data/tokenizers.py
index 20a44c02dc..a7aa40ee7b 100644
--- a/src/gluonnlp/data/tokenizers.py
+++ b/src/gluonnlp/data/tokenizers.py
@@ -1273,7 +1273,7 @@ class SentencepieceTokenizer(BaseTokenizerWithVocab):
algorithm.
alpha
A scalar for a smoothing parameter for probability rescaling.
- do_lower
+ lowercase
Whether to convert the input string to lower-case strings
**kwargs
@@ -1295,7 +1295,7 @@ class SentencepieceTokenizer(BaseTokenizerWithVocab):
def __init__(self, model_path: Optional[str] = None,
vocab: Optional[Union[str, Vocab]] = None,
- nbest: int = 0, alpha: float = 0.0, do_lower=False,
+ nbest: int = 0, alpha: float = 0.0, lowercase=False,
**kwargs):
self._model_path = model_path
sentencepiece = try_import_sentencepiece()
@@ -1305,7 +1305,7 @@ def __init__(self, model_path: Optional[str] = None,
self._sp_model.load(model_path)
self._nbest = nbest
self._alpha = alpha
- self._do_lower = do_lower
+ self._lowercase = lowercase
self._meta_symbol = u'▁'
sp_model_all_tokens = [self._sp_model.id_to_piece(i) for i in range(len(self._sp_model))]
special_tokens_kv = dict()
@@ -1387,7 +1387,7 @@ def encode(self, sentences, output_type=str):
is_multi_sentences = isinstance(sentences, list)
if not is_multi_sentences:
sentences = [sentences]
- if self._do_lower:
+ if self._lowercase:
sentences = [sentence.lower() for sentence in sentences]
if output_type is str:
ret = [self._sp_model.sample_encode_as_pieces(sentence, self._nbest, self._alpha)
@@ -1426,7 +1426,7 @@ def encode_with_offsets(self, sentences, output_type=str):
token_ids = []
offsets = []
for sentence in sentences:
- if self._do_lower:
+ if self._lowercase:
sentence = sentence.lower()
spt = self._spt_cls()
spt.ParseFromString(self._sp_model.SampleEncodeAsSerializedProto(
@@ -1487,8 +1487,8 @@ def set_vocab(self, vocab):
'SentencepieceTokenizer.')
@property
- def do_lower(self):
- return self._do_lower
+ def lowercase(self):
+ return self._lowercase
def set_subword_regularization(self, nbest, alpha):
self._nbest = nbest
@@ -1497,11 +1497,11 @@ def set_subword_regularization(self, nbest, alpha):
def __repr__(self):
ret = '{}(\n' \
' model_path = {}\n' \
- ' do_lower = {}, nbest = {}, alpha = {}\n' \
+ ' lowercase = {}, nbest = {}, alpha = {}\n' \
' vocab = {}\n' \
')'.format(self.__class__.__name__,
os.path.realpath(self._model_path),
- self._do_lower, self._nbest, self._alpha,
+ self._lowercase, self._nbest, self._alpha,
self._vocab)
return ret
diff --git a/src/gluonnlp/models/albert.py b/src/gluonnlp/models/albert.py
index 8ad154a96c..1eb504c643 100644
--- a/src/gluonnlp/models/albert.py
+++ b/src/gluonnlp/models/albert.py
@@ -52,6 +52,7 @@
'vocab': 'google_albert_base_v2/vocab-2ee53ae7.json',
'params': 'google_albert_base_v2/model-125be477.params',
'mlm_params': 'google_albert_base_v2/model_mlm-fe20650e.params',
+ 'lowercase': True,
},
'google_albert_large_v2': {
'cfg': 'google_albert_large_v2/model-e2e9b974.yml',
@@ -59,6 +60,7 @@
'vocab': 'google_albert_large_v2/vocab-2ee53ae7.json',
'params': 'google_albert_large_v2/model-ad60bcd5.params',
'mlm_params': 'google_albert_large_v2/model_mlm-6a5015ee.params',
+ 'lowercase': True,
},
'google_albert_xlarge_v2': {
'cfg': 'google_albert_xlarge_v2/model-8123bffd.yml',
@@ -66,6 +68,7 @@
'vocab': 'google_albert_xlarge_v2/vocab-2ee53ae7.json',
'params': 'google_albert_xlarge_v2/model-4149c9e2.params',
'mlm_params': 'google_albert_xlarge_v2/model_mlm-ee184d38.params',
+ 'lowercase': True,
},
'google_albert_xxlarge_v2': {
'cfg': 'google_albert_xxlarge_v2/model-07fbeebc.yml',
@@ -73,6 +76,7 @@
'vocab': 'google_albert_xxlarge_v2/vocab-2ee53ae7.json',
'params': 'google_albert_xxlarge_v2/model-5601a0ed.params',
'mlm_params': 'google_albert_xxlarge_v2/model_mlm-d2e2b06f.params',
+ 'lowercase': True,
},
}
@@ -373,7 +377,7 @@ def get_cfg(key=None):
return cfg
@classmethod
- def from_cfg(cls, cfg, use_pooler=True) -> 'AlbertModel':
+ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'AlbertModel':
"""
Parameters
@@ -406,7 +410,7 @@ def from_cfg(cls, cfg, use_pooler=True) -> 'AlbertModel':
pos_embed_type=cfg.MODEL.pos_embed_type,
activation=cfg.MODEL.activation,
layer_norm_eps=cfg.MODEL.layer_norm_eps,
- dtype=cfg.MODEL.dtype,
+ dtype=dtype,
embed_initializer=embed_initializer,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
@@ -621,10 +625,11 @@ def get_pretrained_albert(model_name: str = 'google_albert_base_v2',
sha1_hash=FILE_STATS[mlm_params_path])
else:
local_mlm_params_path = None
- # TODO(sxjscience) Move do_lower to assets.
+ do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
+ and PRETRAINED_URL[model_name]['lowercase'] else False
tokenizer = SentencepieceTokenizer(local_paths['spm_model'],
vocab=local_paths['vocab'],
- do_lower=True)
+ lowercase=do_lower)
cfg = AlbertModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_params_path, local_mlm_params_path
diff --git a/src/gluonnlp/models/bert.py b/src/gluonnlp/models/bert.py
index a65244f8c9..194b239b4b 100644
--- a/src/gluonnlp/models/bert.py
+++ b/src/gluonnlp/models/bert.py
@@ -52,6 +52,7 @@
'vocab': 'google_en_cased_bert_base/vocab-c1defaaa.json',
'params': 'google_en_cased_bert_base/model-c566c289.params',
'mlm_params': 'google_en_cased_bert_base/model_mlm-c3ff36a3.params',
+ 'lowercase': False,
},
'google_en_uncased_bert_base': {
@@ -66,6 +67,7 @@
'vocab': 'google_en_cased_bert_large/vocab-c1defaaa.json',
'params': 'google_en_cased_bert_large/model-7aa93704.params',
'mlm_params': 'google_en_cased_bert_large/model_mlm-d6443fe9.params',
+ 'lowercase': False,
},
'google_en_uncased_bert_large': {
'cfg': 'google_en_uncased_bert_large/model-d0c37dcc.yml',
@@ -79,18 +81,21 @@
'vocab': 'google_zh_bert_base/vocab-711c13e4.json',
'params': 'google_zh_bert_base/model-2efbff63.params',
'mlm_params': 'google_zh_bert_base/model_mlm-75339658.params',
+ 'lowercase': False,
},
'google_multi_cased_bert_base': {
'cfg': 'google_multi_cased_bert_base/model-881ad607.yml',
'vocab': 'google_multi_cased_bert_base/vocab-016e1169.json',
'params': 'google_multi_cased_bert_base/model-c2110078.params',
'mlm_params': 'google_multi_cased_bert_base/model_mlm-4611e7a3.params',
+ 'lowercase': False,
},
'google_en_cased_bert_wwm_large': {
'cfg': 'google_en_cased_bert_wwm_large/model-9e127fee.yml',
'vocab': 'google_en_cased_bert_wwm_large/vocab-c1defaaa.json',
'params': 'google_en_cased_bert_wwm_large/model-0fe841cf.params',
'mlm_params': None,
+ 'lowercase': False,
},
'google_en_uncased_bert_wwm_large': {
'cfg': 'google_en_uncased_bert_wwm_large/model-d0c37dcc.yml',
@@ -381,7 +386,7 @@ def get_cfg(key=None):
return cfg
@classmethod
- def from_cfg(cls, cfg, use_pooler=True, dtype='float32'):
+ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'BertModel':
cfg = BertModel.get_cfg().clone_merge(cfg)
assert cfg.VERSION == 1, 'Wrong version!'
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
@@ -616,7 +621,6 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
local_mlm_params_path = None
do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
and PRETRAINED_URL[model_name]['lowercase'] else False
- # TODO(sxjscience) Move do_lower to assets.
tokenizer = HuggingFaceWordPieceTokenizer(
vocab_file=local_paths['vocab'],
unk_token='[UNK]',
diff --git a/src/gluonnlp/models/electra.py b/src/gluonnlp/models/electra.py
index 7f9512d915..f82b945d3f 100644
--- a/src/gluonnlp/models/electra.py
+++ b/src/gluonnlp/models/electra.py
@@ -73,6 +73,7 @@ def get_generator_cfg(model_config):
'params': 'google_electra_small/model-2654c8b4.params',
'disc_model': 'google_electra_small/disc_model-137714b6.params',
'gen_model': 'google_electra_small/gen_model-d11fd0b1.params',
+ 'lowercase': True,
},
'google_electra_base': {
'cfg': 'google_electra_base/model-5b35ca0b.yml',
@@ -80,6 +81,7 @@ def get_generator_cfg(model_config):
'params': 'google_electra_base/model-31c235cc.params',
'disc_model': 'google_electra_base/disc_model-514bd353.params',
'gen_model': 'google_electra_base/gen_model-665ce594.params',
+ 'lowercase': True,
},
'google_electra_large': {
'cfg': 'google_electra_large/model-31b7dfdd.yml',
@@ -87,6 +89,7 @@ def get_generator_cfg(model_config):
'params': 'google_electra_large/model-9baf9ff5.params',
'disc_model': 'google_electra_large/disc_model-5b820c02.params',
'gen_model': 'google_electra_large/gen_model-667121df.params',
+ 'lowercase': True,
}
}
@@ -372,7 +375,7 @@ def get_cfg(key=None):
return cfg
@classmethod
- def from_cfg(cls, cfg, use_pooler=True):
+ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'ElectraModel':
cfg = ElectraModel.get_cfg().clone_merge(cfg)
assert cfg.VERSION == 1, 'Wrong version!'
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
@@ -391,7 +394,7 @@ def from_cfg(cls, cfg, use_pooler=True):
pos_embed_type=cfg.MODEL.pos_embed_type,
activation=cfg.MODEL.activation,
layer_norm_eps=cfg.MODEL.layer_norm_eps,
- dtype=cfg.MODEL.dtype,
+ dtype=dtype,
embed_initializer=embed_initializer,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
@@ -813,7 +816,9 @@ def get_pretrained_electra(model_name: str = 'google_electra_small',
sha1_hash=FILE_STATS[gen_params_path])
else:
local_gen_params_path = None
- # TODO(sxjscience) Move do_lower to assets.
+
+ do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
+ and PRETRAINED_URL[model_name]['lowercase'] else False
tokenizer = HuggingFaceWordPieceTokenizer(
vocab_file=local_paths['vocab'],
unk_token='[UNK]',
@@ -821,7 +826,7 @@ def get_pretrained_electra(model_name: str = 'google_electra_small',
cls_token='[CLS]',
sep_token='[SEP]',
mask_token='[MASK]',
- lowercase=True)
+ lowercase=do_lower)
cfg = ElectraModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_params_path, (local_disc_params_path, local_gen_params_path)
diff --git a/src/gluonnlp/models/mobilebert.py b/src/gluonnlp/models/mobilebert.py
index 65a862d1a3..502d7f4750 100644
--- a/src/gluonnlp/models/mobilebert.py
+++ b/src/gluonnlp/models/mobilebert.py
@@ -54,6 +54,7 @@
'vocab': 'google_uncased_mobilebert/vocab-e6d2b21d.json',
'params': 'google_uncased_mobilebert/model-c8346cf2.params',
'mlm_params': 'google_uncased_mobilebert/model_mlm-53948e82.params',
+ 'lowercase': True,
}
}
@@ -614,10 +615,11 @@ def get_cfg(key=None):
@classmethod
def from_cfg(cls,
cfg,
+ use_pooler=True,
+ dtype='float32',
use_bottleneck=True,
trigram_embed=True,
- use_pooler=True,
- classifier_activation=False):
+ classifier_activation=False) -> 'MobileBertModel':
cfg = MobileBertModel.get_cfg().clone_merge(cfg)
assert cfg.VERSION == 1, 'Wrong version!'
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
@@ -640,7 +642,7 @@ def from_cfg(cls,
activation=cfg.MODEL.activation,
normalization=cfg.MODEL.normalization,
layer_norm_eps=cfg.MODEL.layer_norm_eps,
- dtype=cfg.MODEL.dtype,
+ dtype=dtype,
embed_initializer=embed_initializer,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
@@ -904,7 +906,6 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert',
local_mlm_params_path = None
do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
and PRETRAINED_URL[model_name]['lowercase'] else False
- # TODO(sxjscience) Move do_lower to assets.
tokenizer = HuggingFaceWordPieceTokenizer(
vocab_file=local_paths['vocab'],
unk_token='[UNK]',
diff --git a/src/gluonnlp/models/model_zoo_checksums/roberta.txt b/src/gluonnlp/models/model_zoo_checksums/roberta.txt
index 4e4f9efe6d..6de6e8ce5f 100644
--- a/src/gluonnlp/models/model_zoo_checksums/roberta.txt
+++ b/src/gluonnlp/models/model_zoo_checksums/roberta.txt
@@ -1,12 +1,10 @@
-fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401
-fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
-fairseq_roberta_base/model-98b4532f.params 98b4532fe59e6fd755422057fde4601b3eb8fbf0 498792661
-fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
-fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
-fairseq_roberta_large/model-e3f578dc.params e3f578dc669cf36fa5b6730b0bbee77c980276d7 1421659773
-fairseq_roberta_large_mnli/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
-fairseq_roberta_large_mnli/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-fairseq_roberta_large_mnli/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
-fairseq_roberta_large_mnli/model-5288bb09.params 5288bb09db89b7900e85c9d673686f748f0abd56 1421659773
\ No newline at end of file
+fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401
+fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
+fairseq_roberta_base/model-09a1520a.params 09a1520adf652468c07e43a6ed28908418fa58a7 496222787
+fairseq_roberta_base/model_mlm-29889e2b.params 29889e2b4ef20676fda117bb7b754e1693d0df25 498794868
+fairseq_roberta_large/model-6b043b91.params 6b043b91a6a781a12ea643d0644d32300db38ec8 1417251819
+fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
+fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
+fairseq_roberta_large/model_mlm-119f38e1.params 119f38e1249bd28bea7dd2e90c09b8f4b879fa19 1421664140
diff --git a/src/gluonnlp/models/model_zoo_checksums/xlmr.txt b/src/gluonnlp/models/model_zoo_checksums/xlmr.txt
index 94a81c5cf9..355584fcd5 100644
--- a/src/gluonnlp/models/model_zoo_checksums/xlmr.txt
+++ b/src/gluonnlp/models/model_zoo_checksums/xlmr.txt
@@ -1,6 +1,8 @@
-fairseq_xlmr_base/model-b893d178.yml b893d178fa859fb6c708a08fc970b9980e047825 402
-fairseq_xlmr_base/sentencepiece-18e17bae.model 18e17bae37be115135d4cf4ad9dfcc4f3b12cb80 5069075
-fairseq_xlmr_base/model-340f4fa8.params 340f4fa8e086ac5f57a59c999a47d7efa343f900 1113183673
-fairseq_xlmr_large/model-01fc59fb.yml 01fc59fb3a805f09d2aa11369d5b57e0be931fdd 403
-fairseq_xlmr_large/sentencepiece-18e17bae.model 18e17bae37be115135d4cf4ad9dfcc4f3b12cb80 5069075
-fairseq_xlmr_large/model-e4b11125.params e4b11125312a54a130e87b480e37ff43a2ad0f2d 2240581473
\ No newline at end of file
+fairseq_xlmr_base/model-3fa134e9.params 3fa134e9a13e2329ffa7b8d39612695ed8397c9d 1109814851
+fairseq_xlmr_base/model-b893d178.yml b893d178fa859fb6c708a08fc970b9980e047825 402
+fairseq_xlmr_base/model_mlm-86e37954.params 86e379542a6430cd988ff4b6a25966949afc241a 1113185880
+fairseq_xlmr_base/sentencepiece-18e17bae.model 18e17bae37be115135d4cf4ad9dfcc4f3b12cb80 5069075
+fairseq_xlmr_large/model-b62b074c.params b62b074cdd41e682075e2407f842be6578696b26 2235374571
+fairseq_xlmr_large/model-01fc59fb.yml 01fc59fb3a805f09d2aa11369d5b57e0be931fdd 403
+fairseq_xlmr_large/model_mlm-887506c2.params 887506c20bda452cf13ef04390eaa57a55602a92 2240585840
+fairseq_xlmr_large/sentencepiece-18e17bae.model 18e17bae37be115135d4cf4ad9dfcc4f3b12cb80 5069075
diff --git a/src/gluonnlp/models/roberta.py b/src/gluonnlp/models/roberta.py
index 44a1a20c1e..8400f89fbd 100644
--- a/src/gluonnlp/models/roberta.py
+++ b/src/gluonnlp/models/roberta.py
@@ -27,22 +27,26 @@
}
"""
-__all__ = ['RobertaModel', 'list_pretrained_roberta', 'get_pretrained_roberta']
+__all__ = ['RobertaModel', 'RobertaForMLM', 'list_pretrained_roberta', 'get_pretrained_roberta']
-from typing import Tuple
import os
+from typing import Tuple
+
import mxnet as mx
from mxnet import use_np
-from mxnet.gluon import nn, HybridBlock
+from mxnet.gluon import HybridBlock, nn
+
+from ..op import select_vectors_by_position
+from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, \
+ get_model_zoo_checksum_dir
+from ..layers import PositionalEmbedding, get_activation
+from ..registry import BACKBONE_REGISTRY
+from ..utils.misc import download, load_checksum_stats
from .transformer import TransformerEncoderLayer
-from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
-from ..utils.config import CfgNode as CN
-from ..utils.registry import Registry
-from ..utils.misc import load_checksum_stats, download
from ..initializer import TruncNorm
+from ..utils.config import CfgNode as CN
from ..attention_cell import gen_self_attn_mask
-from ..registry import BACKBONE_REGISTRY
-from ..layers import PositionalEmbedding, get_activation
+from ..utils.registry import Registry
from ..data.tokenizers import HuggingFaceByteBPETokenizer
PRETRAINED_URL = {
@@ -50,19 +54,17 @@
'cfg': 'fairseq_roberta_base/model-565d1db7.yml',
'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab',
- 'params': 'fairseq_roberta_base/model-98b4532f.params'
+ 'params': 'fairseq_roberta_base/model-09a1520a.params',
+ 'mlm_params': 'fairseq_roberta_base/model_mlm-29889e2b.params',
+ 'lowercase': False,
},
'fairseq_roberta_large': {
'cfg': 'fairseq_roberta_large/model-6e66dc4a.yml',
'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab',
- 'params': 'fairseq_roberta_large/model-e3f578dc.params'
- },
- 'fairseq_roberta_large_mnli': {
- 'cfg': 'fairseq_roberta_large_mnli/model-6e66dc4a.yml',
- 'merges': 'fairseq_roberta_large_mnli/gpt2-396d4d8e.merges',
- 'vocab': 'fairseq_roberta_large_mnli/gpt2-f1335494.vocab',
- 'params': 'fairseq_roberta_large_mnli/model-5288bb09.params'
+ 'params': 'fairseq_roberta_large/model-6b043b91.params',
+ 'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params',
+ 'lowercase': False,
}
}
@@ -96,6 +98,7 @@ def roberta_base():
cfg.freeze()
return cfg
+
@roberta_cfg_reg.register()
def roberta_large():
cfg = roberta_base()
@@ -107,6 +110,7 @@ def roberta_large():
cfg.freeze()
return cfg
+
@use_np
class RobertaModel(HybridBlock):
def __init__(self,
@@ -126,12 +130,11 @@ def __init__(self,
weight_initializer=TruncNorm(stdev=0.02),
bias_initializer='zeros',
dtype='float32',
- use_pooler=False,
- use_mlm=True,
- untie_weight=False,
+ use_pooler=True,
+ classifier_activation=False,
encoder_normalize_before=True,
- return_all_hiddens=False):
- """
+ output_all_encodings=False):
+ """
Parameters
----------
@@ -152,15 +155,15 @@ def __init__(self,
bias_initializer
dtype
use_pooler
+ Whether to output the CLS hidden state
+ classifier_activation
Whether to use classification head
- use_mlm
- Whether to use lm head, if False, forward return hidden states only
- untie_weight
- Whether to untie weights between embeddings and classifiers
encoder_normalize_before
- return_all_hiddens
+ output_all_encodings
"""
super().__init__()
+ self._dtype = dtype
+ self._output_all_encodings = output_all_encodings
self.vocab_size = vocab_size
self.units = units
self.hidden_size = hidden_size
@@ -173,32 +176,28 @@ def __init__(self,
self.activation = activation
self.pooler_activation = pooler_activation
self.layer_norm_eps = layer_norm_eps
- self.dtype = dtype
self.use_pooler = use_pooler
- self.use_mlm = use_mlm
- self.untie_weight = untie_weight
+ self.classifier_activation = classifier_activation
self.encoder_normalize_before = encoder_normalize_before
- self.return_all_hiddens = return_all_hiddens
- self.tokens_embed = nn.Embedding(
+ self.weight_initializer = weight_initializer
+ self.bias_initializer = bias_initializer
+
+ self.word_embed = nn.Embedding(
input_dim=self.vocab_size,
output_dim=self.units,
weight_initializer=embed_initializer,
- dtype=self.dtype,
+ dtype=self._dtype
)
if self.encoder_normalize_before:
self.embed_ln = nn.LayerNorm(
epsilon=self.layer_norm_eps,
- in_channels=self.units,
- )
- else:
- self.embed_ln = None
+ in_channels=self.units)
self.embed_dropout = nn.Dropout(self.hidden_dropout_prob)
self.pos_embed = PositionalEmbedding(
units=self.units,
max_length=self.max_length,
- dtype=self.dtype,
- method=pos_embed_type,
- )
+ dtype=self._dtype,
+ method=pos_embed_type)
self.encoder = RobertaEncoder(
units=self.units,
@@ -211,42 +210,76 @@ def __init__(self,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
activation=self.activation,
- dtype=self.dtype,
- return_all_hiddens=self.return_all_hiddens
+ dtype=self._dtype,
+ output_all_encodings=self._output_all_encodings
)
self.encoder.hybridize()
- if self.use_mlm:
- self.lm_head = RobertaLMHead(
- self.units,
- self.vocab_size,
- self.activation,
- layer_norm_eps=self.layer_norm_eps,
- weight_initializer=weight_initializer,
- bias_initializer=bias_initializer
- )
- if not untie_weight:
- self.lm_head.dense2.weight = self.tokens_embed.weight
- self.lm_head.hybridize()
- # TODO support use_pooler
+ if self.use_pooler and self.classifier_activation:
+ # Construct pooler
+ self.pooler = nn.Dense(units=self.units,
+ in_units=self.units,
+ flatten=False,
+ activation=self.pooler_activation,
+ weight_initializer=weight_initializer,
+ bias_initializer=bias_initializer)
def hybrid_forward(self, F, tokens, valid_length):
- x = self.tokens_embed(tokens)
+ outputs = []
+ embedding = self.get_initial_embedding(F, tokens)
+
+ contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length)
+ outputs.append(contextual_embeddings)
+ if self._output_all_encodings:
+ contextual_embeddings = contextual_embeddings[-1]
+
+ if self.use_pooler:
+ pooled_out = self.apply_pooling(contextual_embeddings)
+ outputs.append(pooled_out)
+
+ return tuple(outputs) if len(outputs) > 1 else outputs[0]
+
+ def get_initial_embedding(self, F, inputs):
+ """Get the initial token embeddings that considers the token type and positional embeddings
+
+ Parameters
+ ----------
+ F
+ inputs
+ Shape (batch_size, seq_length)
+
+ Returns
+ -------
+ embedding
+ The initial embedding that will be fed into the encoder
+ """
+ embedding = self.word_embed(inputs)
if self.pos_embed_type:
- positional_embedding = self.pos_embed(F.npx.arange_like(x, axis=1))
+ positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=1))
positional_embedding = F.np.expand_dims(positional_embedding, axis=0)
- x = x + positional_embedding
- if self.embed_ln:
- x = self.embed_ln(x)
- x = self.embed_dropout(x)
- inner_states = self.encoder(x, valid_length)
- x = inner_states[-1]
- if self.use_mlm:
- x = self.lm_head(x)
- if self.return_all_hiddens:
- return x, inner_states
+ embedding = embedding + positional_embedding
+ if self.encoder_normalize_before:
+ embedding = self.embed_ln(embedding)
+ embedding = self.embed_dropout(embedding)
+
+ return embedding
+
+ def apply_pooling(self, sequence):
+ """Generate the representation given the inputs.
+
+        This is used for pre-training or fine-tuning a RoBERTa model.
+        Get the first token of the whole sequence, which is `<s>` (the [CLS] equivalent).
+
+ sequence:
+ Shape (batch_size, sequence_length, units)
+ return:
+ Shape (batch_size, units)
+ """
+ outputs = sequence[:, 0, :]
+ if self.classifier_activation:
+ return self.pooler(outputs)
else:
- return x
+ return outputs
@staticmethod
def get_cfg(key=None):
@@ -258,11 +291,11 @@ def get_cfg(key=None):
@classmethod
def from_cfg(cls,
cfg,
- use_pooler=False,
- use_mlm=True,
- untie_weight=False,
+ use_pooler=True,
+ dtype='float32',
+ classifier_activation=False,
encoder_normalize_before=True,
- return_all_hiddens=False):
+ output_all_encodings=False) -> 'RobertaModel':
cfg = RobertaModel.get_cfg().clone_merge(cfg)
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
weight_initializer = mx.init.create(*cfg.INITIALIZER.weight)
@@ -282,15 +315,14 @@ def from_cfg(cls,
embed_initializer=embed_initializer,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
- dtype=cfg.MODEL.dtype,
+ dtype=dtype,
use_pooler=use_pooler,
- use_mlm=use_mlm,
- untie_weight=untie_weight,
encoder_normalize_before=encoder_normalize_before,
- return_all_hiddens=return_all_hiddens)
+ output_all_encodings=output_all_encodings)
+
@use_np
-class RobertaEncoder(HybridBlock):
+class RobertaEncoder(HybridBlock):
def __init__(self,
units=768,
hidden_size=3072,
@@ -303,7 +335,8 @@ def __init__(self,
bias_initializer='zeros',
activation='gelu',
dtype='float32',
- return_all_hiddens=False):
+ output_all_encodings=False,
+ output_attention=False):
super().__init__()
self.units = units
self.hidden_size = hidden_size
@@ -313,8 +346,9 @@ def __init__(self,
self.hidden_dropout_prob = hidden_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.activation = activation
- self.dtype = dtype
- self.return_all_hiddens = return_all_hiddens
+ self._dtype = dtype
+ self._output_all_encodings = output_all_encodings
+ self._output_attention = output_attention
self.all_layers = nn.HybridSequential()
for layer_idx in range(self.num_layers):
self.all_layers.add(
@@ -328,54 +362,103 @@ def __init__(self,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
activation=self.activation,
- dtype=self.dtype,
- )
+ dtype=self._dtype)
)
def hybrid_forward(self, F, x, valid_length):
atten_mask = gen_self_attn_mask(F, x, valid_length,
- dtype=self.dtype, attn_type='full')
- inner_states = [x]
+ dtype=self._dtype, attn_type='full')
+ all_encodings_outputs = [x]
+ additional_outputs = []
for layer_idx in range(self.num_layers):
layer = self.all_layers[layer_idx]
- x, _ = layer(x, atten_mask)
- inner_states.append(x)
- if not self.return_all_hiddens:
- inner_states = [x]
- return inner_states
+ x, attention_weights = layer(x, atten_mask)
+ if self._output_all_encodings:
+ all_encodings_outputs.append(x)
+ if self._output_attention:
+ additional_outputs.append(attention_weights)
+ # sequence_mask is not necessary here because masking could be performed in downstream tasks
+ if self._output_all_encodings:
+ return all_encodings_outputs, additional_outputs
+ else:
+ return x, additional_outputs
+
@use_np
-class RobertaLMHead(HybridBlock):
- def __init__(self,
- embed_dim=768,
- output_dim=50265,
- activation_fn='gelu',
- layer_norm_eps=1E-5,
- weight_initializer=TruncNorm(stdev=0.02),
- bias_initializer='zeros'):
+class RobertaForMLM(HybridBlock):
+ def __init__(self, backbone_cfg,
+ weight_initializer=None,
+ bias_initializer=None):
+ """
+
+ Parameters
+ ----------
+ backbone_cfg
+ weight_initializer
+ bias_initializer
+ """
super().__init__()
- self.dense1 = nn.Dense(in_units=embed_dim,
- units=embed_dim,
- flatten=False,
- weight_initializer=weight_initializer,
- bias_initializer=bias_initializer)
- self.activation_fn = get_activation(activation_fn)
- self.ln = nn.LayerNorm(
- epsilon=layer_norm_eps,
- in_channels=embed_dim)
- self.dense2 = nn.Dense(in_units=embed_dim,
- units=output_dim,
- activation=None,
- flatten=False,
- weight_initializer=weight_initializer,
- bias_initializer='zeros')
-
- def hybrid_forward(self, F, x):
- x = self.dense1(x)
- x = self.activation_fn(x)
- x = self.ln(x)
- x = self.dense2(x)
- return x
+ self.backbone_model = RobertaModel.from_cfg(backbone_cfg)
+ if weight_initializer is None:
+ weight_initializer = self.backbone_model.weight_initializer
+ if bias_initializer is None:
+ bias_initializer = self.backbone_model.bias_initializer
+ self.units = self.backbone_model.units
+ self.mlm_decoder = nn.HybridSequential()
+ # Extra non-linear layer
+ self.mlm_decoder.add(nn.Dense(units=self.units,
+ in_units=self.units,
+ flatten=False,
+ weight_initializer=weight_initializer,
+ bias_initializer=bias_initializer))
+ self.mlm_decoder.add(get_activation(self.backbone_model.activation))
+ self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+ in_channels=self.units))
+        # The weight of this dense layer is tied to the word embedding below;
+        # only its bias is a separate, newly initialized parameter
+        # (the embedding layer itself carries no bias).
+ self.mlm_decoder.add(
+ nn.Dense(
+ units=self.backbone_model.vocab_size,
+ in_units=self.units,
+ flatten=False,
+ bias_initializer=bias_initializer))
+ self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
+ self.mlm_decoder.hybridize()
+
+ def hybrid_forward(self, F, inputs, valid_length, masked_positions):
+ """Getting the scores of the masked positions.
+
+ Parameters
+ ----------
+ F
+ inputs :
+ Shape (batch_size, seq_length)
+ valid_length :
+ The valid length of each sequence
+ Shape (batch_size,)
+ masked_positions :
+ The masked position of the sequence
+ Shape (batch_size, num_masked_positions).
+
+ Returns
+ -------
+ contextual_embedding
+ Shape (batch_size, seq_length, units).
+ pooled_out
+ Shape (batch_size, units)
+ mlm_scores :
+ Shape (batch_size, num_masked_positions, vocab_size)
+ """
+
+ all_encodings_outputs, pooled_out = self.backbone_model(inputs, valid_length)
+ if self.backbone_model._output_all_encodings:
+ contextual_embeddings = all_encodings_outputs[-1]
+ else:
+ contextual_embeddings = all_encodings_outputs
+ mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions)
+ mlm_scores = self.mlm_decoder(mlm_features)
+ return all_encodings_outputs, pooled_out, mlm_scores
def list_pretrained_roberta():
@@ -383,7 +466,9 @@ def list_pretrained_roberta():
def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
- root: str = get_model_zoo_home_dir()) \
+ root: str = get_model_zoo_home_dir(),
+ load_backbone: bool = True,
+ load_mlm: bool = False) \
-        -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
+        -> Tuple[CN, HuggingFaceByteBPETokenizer, str, str]:
"""Get the pretrained RoBERTa weights
@@ -393,6 +478,10 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
The name of the RoBERTa model.
root
The downloading root
+ load_backbone
+ Whether to load the weights of the backbone network
+ load_mlm
+ Whether to load the weights of MLM
Returns
-------
@@ -402,6 +491,8 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
The HuggingFaceByteBPETokenizer
params_path
Path to the parameters
+ mlm_params_path
+        Path to the parameters that include both the backbone and the MLM head
"""
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_roberta())
@@ -409,15 +500,34 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
merges_path = PRETRAINED_URL[model_name]['merges']
vocab_path = PRETRAINED_URL[model_name]['vocab']
params_path = PRETRAINED_URL[model_name]['params']
+ mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']
+
local_paths = dict()
for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
- ('merges', merges_path), ('params', params_path)]:
+ ('merges', merges_path)]:
local_paths[k] = download(url=get_repo_model_zoo_url() + path,
path=os.path.join(root, path),
sha1_hash=FILE_STATS[path])
- tokenizer = HuggingFaceByteBPETokenizer(local_paths['merges'], local_paths['vocab'])
+ if load_backbone:
+ local_params_path = download(url=get_repo_model_zoo_url() + params_path,
+ path=os.path.join(root, params_path),
+ sha1_hash=FILE_STATS[params_path])
+ else:
+ local_params_path = None
+ if load_mlm and mlm_params_path is not None:
+ local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
+ path=os.path.join(root, mlm_params_path),
+ sha1_hash=FILE_STATS[mlm_params_path])
+ else:
+ local_mlm_params_path = None
+ do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
+ and PRETRAINED_URL[model_name]['lowercase'] else False
+ tokenizer = HuggingFaceByteBPETokenizer(
+ merges_file=local_paths['merges'],
+ vocab_file=local_paths['vocab'],
+ lowercase=do_lower)
cfg = RobertaModel.get_cfg().clone_merge(local_paths['cfg'])
- return cfg, tokenizer, local_paths['params']
+ return cfg, tokenizer, local_params_path, local_mlm_params_path
BACKBONE_REGISTRY.register('roberta', [RobertaModel,
diff --git a/src/gluonnlp/models/xlmr.py b/src/gluonnlp/models/xlmr.py
index 70ad76c11f..b433d34157 100644
--- a/src/gluonnlp/models/xlmr.py
+++ b/src/gluonnlp/models/xlmr.py
@@ -25,12 +25,12 @@
}
"""
-__all__ = ['XLMRModel', 'list_pretrained_xlmr', 'get_pretrained_xlmr']
+__all__ = ['XLMRModel', 'XLMRForMLM', 'list_pretrained_xlmr', 'get_pretrained_xlmr']
from typing import Tuple
import os
from mxnet import use_np
-from .roberta import RobertaModel, roberta_base, roberta_large
+from .roberta import RobertaModel, RobertaForMLM, roberta_base, roberta_large
from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
from ..utils.config import CfgNode as CN
from ..utils.registry import Registry
@@ -43,17 +43,21 @@
'fairseq_xlmr_base': {
'cfg': 'fairseq_xlmr_base/model-b893d178.yml',
'sentencepiece.model': 'fairseq_xlmr_base/sentencepiece-18e17bae.model',
- 'params': 'fairseq_xlmr_base/model-340f4fa8.params'
+ 'params': 'fairseq_xlmr_base/model-3fa134e9.params',
+ 'mlm_params': 'fairseq_xlmr_base/model_mlm-86e37954.params',
+ 'lowercase': False,
},
'fairseq_xlmr_large': {
'cfg': 'fairseq_xlmr_large/model-01fc59fb.yml',
'sentencepiece.model': 'fairseq_xlmr_large/sentencepiece-18e17bae.model',
- 'params': 'fairseq_xlmr_large/model-e4b11125.params'
+ 'params': 'fairseq_xlmr_large/model-b62b074c.params',
+ 'mlm_params': 'fairseq_xlmr_large/model_mlm-887506c2.params',
+ 'lowercase': False,
}
}
FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'xlmr.txt'))
-xlmr_cfg_reg = Registry('roberta_cfg')
+xlmr_cfg_reg = Registry('xlmr_cfg')
@xlmr_cfg_reg.register()
@@ -82,14 +86,18 @@ def get_cfg(key=None):
return xlmr_cfg_reg.create(key)
else:
return xlmr_base()
-
+@use_np
+class XLMRForMLM(RobertaForMLM):
+ pass
def list_pretrained_xlmr():
return sorted(list(PRETRAINED_URL.keys()))
def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
- root: str = get_model_zoo_home_dir()) \
+ root: str = get_model_zoo_home_dir(),
+ load_backbone: bool = True,
+ load_mlm: bool = False) \
-        -> Tuple[CN, SentencepieceTokenizer, str]:
+        -> Tuple[CN, SentencepieceTokenizer, str, str]:
"""Get the pretrained XLM-R weights
@@ -99,6 +107,10 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
The name of the xlmr model.
root
The downloading root
+ load_backbone
+ Whether to load the weights of the backbone network
+ load_mlm
+ Whether to load the weights of MLM
Returns
-------
@@ -108,21 +120,40 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
The SentencepieceTokenizer
params_path
Path to the parameters
+ mlm_params_path
+        Path to the parameters that include both the backbone and the MLM head
"""
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_xlmr())
cfg_path = PRETRAINED_URL[model_name]['cfg']
sp_model_path = PRETRAINED_URL[model_name]['sentencepiece.model']
params_path = PRETRAINED_URL[model_name]['params']
+ mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']
local_paths = dict()
- for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path), \
- ('params', params_path)]:
+ for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path)]:
local_paths[k] = download(url=get_repo_model_zoo_url() + path,
path=os.path.join(root, path),
sha1_hash=FILE_STATS[path])
- tokenizer = SentencepieceTokenizer(local_paths['sentencepiece.model'])
+ if load_backbone:
+ local_params_path = download(url=get_repo_model_zoo_url() + params_path,
+ path=os.path.join(root, params_path),
+ sha1_hash=FILE_STATS[params_path])
+ else:
+ local_params_path = None
+ if load_mlm and mlm_params_path is not None:
+ local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
+ path=os.path.join(root, mlm_params_path),
+ sha1_hash=FILE_STATS[mlm_params_path])
+ else:
+ local_mlm_params_path = None
+
+ do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\
+ and PRETRAINED_URL[model_name]['lowercase'] else False
+ tokenizer = SentencepieceTokenizer(
+ model_path=local_paths['sentencepiece.model'],
+ lowercase=do_lower)
cfg = XLMRModel.get_cfg().clone_merge(local_paths['cfg'])
- return cfg, tokenizer, local_paths['params']
+ return cfg, tokenizer, local_params_path, local_mlm_params_path
BACKBONE_REGISTRY.register('xlmr', [XLMRModel,
diff --git a/tests/test_data_tokenizers.py b/tests/test_data_tokenizers.py
index 6fa4b15961..148a82c396 100644
--- a/tests/test_data_tokenizers.py
+++ b/tests/test_data_tokenizers.py
@@ -117,14 +117,14 @@ def verify_decode_spm(tokenizer, all_sentences, gt_int_decode_sentences):
(all_sentences, gt_int_decode_sentences)]:
if isinstance(sentences, str):
gt_str_decode_sentences = sentences
- if tokenizer.do_lower:
+ if tokenizer.lowercase:
gt_str_decode_sentences = gt_str_decode_sentences.lower()
gt_str_decode_sentences = unicodedata.normalize('NFKC', gt_str_decode_sentences)
elif isinstance(sentences, list):
gt_str_decode_sentences = []
for ele in sentences:
ele_gt_decode = ele
- if tokenizer.do_lower:
+ if tokenizer.lowercase:
ele_gt_decode = ele_gt_decode.lower()
ele_gt_decode = unicodedata.normalize('NFKC', ele_gt_decode)
gt_str_decode_sentences.append(ele_gt_decode)
@@ -379,11 +379,11 @@ def test_sentencepiece_tokenizer():
gt_lower_case_int_decode = ['hello, y ⁇ all! how are you viii ⁇ ⁇ ⁇ ?',
'gluonnlp is great!!!!!!',
'gluonnlp-amazon-haibin-leonard-sheng-shuai-xingjian...../:! ⁇ # ⁇ abc ⁇ ']
- tokenizer = SentencepieceTokenizer(model_path, do_lower=True)
+ tokenizer = SentencepieceTokenizer(model_path, lowercase=True)
verify_decode_spm(tokenizer, SUBWORD_TEST_SAMPLES, gt_lower_case_int_decode)
# Case3, Use the sentencepiece regularization commands, we test whether we can obtain different encoding results
- tokenizer = SentencepieceTokenizer(model_path, do_lower=True, nbest=-1, alpha=1.0)
+ tokenizer = SentencepieceTokenizer(model_path, lowercase=True, nbest=-1, alpha=1.0)
has_different_encode_out = False
encode_out = None
for _ in range(10):
diff --git a/tests/test_models_roberta.py b/tests/test_models_roberta.py
index 2ae7d1cd97..9511c51472 100644
--- a/tests/test_models_roberta.py
+++ b/tests/test_models_roberta.py
@@ -2,7 +2,7 @@
import numpy as np
import mxnet as mx
import tempfile
-from gluonnlp.models.roberta import RobertaModel,\
+from gluonnlp.models.roberta import RobertaModel, RobertaForMLM, \
list_pretrained_roberta, get_pretrained_roberta
from gluonnlp.loss import LabelSmoothCrossEntropyLoss
@@ -19,12 +19,19 @@ def test_roberta(model_name):
# test from pretrained
assert len(list_pretrained_roberta()) > 0
with tempfile.TemporaryDirectory() as root:
- cfg, tokenizer, params_path =\
- get_pretrained_roberta(model_name, root=root)
+ cfg, tokenizer, params_path, mlm_params_path =\
+ get_pretrained_roberta(model_name, load_backbone=True, load_mlm=True, root=root)
assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
+ # test backbone
roberta_model = RobertaModel.from_cfg(cfg)
roberta_model.load_parameters(params_path)
-
+ # test mlm model
+ roberta_mlm_model = RobertaForMLM(cfg)
+ if mlm_params_path is not None:
+ roberta_mlm_model.load_parameters(mlm_params_path)
+ roberta_mlm_model = RobertaForMLM(cfg)
+ roberta_mlm_model.backbone_model.load_parameters(params_path)
+
# test forward
batch_size = 3
seq_length = 32
@@ -45,12 +52,12 @@ def test_roberta(model_name):
),
dtype=np.int32
)
- x = roberta_model(input_ids, valid_length)
+ contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
mx.npx.waitall()
# test backward
label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
with mx.autograd.record():
- x = roberta_model(input_ids, valid_length)
- loss = label_smooth_loss(x, input_ids)
+ contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
+ loss = label_smooth_loss(contextual_embeddings, input_ids)
loss.backward()
mx.npx.waitall()
diff --git a/tests/test_models_xlmr.py b/tests/test_models_xlmr.py
index 74538ef327..f8f9ec76fe 100644
--- a/tests/test_models_xlmr.py
+++ b/tests/test_models_xlmr.py
@@ -2,7 +2,7 @@
import numpy as np
import mxnet as mx
import tempfile
-from gluonnlp.models.xlmr import XLMRModel,\
+from gluonnlp.models.xlmr import XLMRModel, XLMRForMLM, \
list_pretrained_xlmr, get_pretrained_xlmr
from gluonnlp.loss import LabelSmoothCrossEntropyLoss
@@ -19,10 +19,14 @@ def test_xlmr():
assert len(list_pretrained_xlmr()) > 0
for model_name in ['fairseq_xlmr_base']:
with tempfile.TemporaryDirectory() as root:
- cfg, tokenizer, params_path = get_pretrained_xlmr(model_name, root=root)
+ cfg, tokenizer, params_path, mlm_params_path =\
+ get_pretrained_xlmr(model_name, load_backbone=True, load_mlm=False, root=root)
assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
+ # test backbone
xlmr_model = XLMRModel.from_cfg(cfg)
xlmr_model.load_parameters(params_path)
+            # the MLM model is skipped here since load_mlm=False
+
# test forward
batch_size = 1
seq_length = 8
@@ -43,12 +47,12 @@ def test_xlmr():
),
dtype=np.int32
)
- x = xlmr_model(input_ids, valid_length)
+ contextual_embeddings, pooled_out = xlmr_model(input_ids, valid_length)
mx.npx.waitall()
# test backward
label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
with mx.autograd.record():
- x = xlmr_model(input_ids, valid_length)
- loss = label_smooth_loss(x, input_ids)
+ contextual_embeddings, pooled_out = xlmr_model(input_ids, valid_length)
+ loss = label_smooth_loss(contextual_embeddings, input_ids)
loss.backward()
mx.npx.waitall()