Merge pull request #14 from microsoft/raviskolli/ort
Remove model specific changes for BERT and DistilBERT
raviskolli authored May 19, 2021
2 parents 0b2532a + 239767d commit efc9019
Showing 2 changed files with 2 additions and 10 deletions.
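Both hunks below delete the same pattern: a config.ort flag stored on the model and a branch that cast the prediction logits to torch.float32 before the masked-LM cross-entropy loss when training under ONNX Runtime, presumably because the logits can be half precision under mixed-precision training. After this commit the loss is computed the same way regardless of backend. The following is a minimal, self-contained sketch of the removed behavior, for illustration only; the mlm_loss helper, its signature, and the tensor shapes are assumptions, not code from this repository.

import torch
from torch.nn import CrossEntropyLoss

def mlm_loss(prediction_scores, labels, vocab_size, ort=False):
    # Hypothetical helper reproducing the deleted branch: under ORT
    # mixed-precision training, logits were upcast so the log-softmax
    # inside the loss runs in float32 for numerical stability.
    loss_fct = CrossEntropyLoss()  # -100 index = padding token
    logits = prediction_scores.view(-1, vocab_size)
    if ort:
        logits = logits.to(torch.float32)  # the cast this commit removes
    return loss_fct(logits, labels.view(-1))

# Example: fp16 logits for a (batch=2, seq=8) MLM batch over a 30522-token vocab.
scores = torch.randn(2, 8, 30522, dtype=torch.float16)
labels = torch.randint(0, 30522, (2, 8))
loss = mlm_loss(scores, labels, vocab_size=30522, ort=True)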
6 changes: 1 addition & 5 deletions src/transformers/models/bert/modeling_bert.py
@@ -1266,7 +1266,6 @@ def __init__(self, config):

         self.bert = BertModel(config, add_pooling_layer=False)
         self.cls = BertOnlyMLMHead(config)
-        self.ort = config.ort
 
         self.init_weights()

@@ -1327,10 +1326,7 @@ def forward(
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()  # -100 index = padding token
-            if self.ort:
-                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).to(torch.float32), labels.view(-1))
-            else:
-                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
6 changes: 1 addition & 5 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -500,7 +500,6 @@ def __init__(self, config):
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
-        self.ort = config.ort
 
         self.init_weights()

@@ -555,10 +554,7 @@ def forward(

         mlm_loss = None
         if labels is not None:
-            if self.ort:
-                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)).to(torch.float32), labels.view(-1))
-            else:
-                mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
 
         if not return_dict:
             output = (prediction_logits,) + dlbrt_output[1:]
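The DistilBERT hunk mirrors the BERT one, so after this commit neither modeling file reads config.ort. If the float32 upcast is still needed for ORT mixed-precision training, it is presumably applied outside the modeling code (for example in the training loop or a loss wrapper); that location is an assumption and is not shown in this diff.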
