-
Notifications
You must be signed in to change notification settings - Fork 0
/
hypo_select.py
111 lines (94 loc) · 3.53 KB
/
hypo_select.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import random
import time
import os
import datetime
import numpy as np
import pandas as pd
from transformers import get_linear_schedule_with_warmup
from transformers import BertForSequenceClassification, AdamW, BertConfig
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset, random_split
from transformers import AutoTokenizer
import torch
from torch import nn
from sentence_transformers import SentenceTransformer, util
# Run on GPU when available; every model below is moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fine-tuned BERT classifier for (hypothesis, paragraph) pair scoring.
# NOTE(review): loaded but unused by the ACTIVE get_hypo_pair_score below —
# only the commented-out alternative implementation uses it.
pair_model_path = "model_save/hypo_pair_gold"
pair_model = BertForSequenceClassification.from_pretrained(pair_model_path)
tokenizer = AutoTokenizer.from_pretrained(pair_model_path)
# DataParallel wrapper: replicates across all visible GPUs (no-op overhead on CPU).
pair_model = nn.DataParallel(pair_model)
pair_model.to(device)
# Fine-tuned BERT classifier judging a hypothesis on its own (used by
# get_hypo_judge_score).
judge_model_path = "model_save/hypo_judge_gold"
judge_model = BertForSequenceClassification.from_pretrained(judge_model_path)
judge_tokenizer = AutoTokenizer.from_pretrained(judge_model_path)
judge_model = nn.DataParallel(judge_model)
judge_model.to(device)
# Sentence embedder used by the active get_hypo_pair_score (cosine similarity).
model = SentenceTransformer('model_save/hypo-pair-paraphrase-distilroberta-base-v1')
def get_hypo_pair_score(para, cands):
    """Score how well each candidate hypothesis matches a paragraph.

    Embeds the paragraph and every candidate with the module-level
    SentenceTransformer ``model`` and returns their cosine similarities.

    Args:
        para: The paragraph string.
        cands: List of candidate hypothesis strings.

    Returns:
        1-D numpy array (float64) of length ``len(cands)`` with the cosine
        similarity between ``para`` and each candidate.
    """
    # Encode the paragraph ONCE instead of replicating it len(cands) times;
    # comparing a 1-vector against n vectors gives a (1, n) similarity matrix
    # whose first row is exactly the diagonal the original loop extracted.
    para_emb = model.encode(para, convert_to_tensor=True)
    cand_embs = model.encode(cands, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(para_emb, cand_embs)
    # float64 to match the dtype np.array produced from Python floats before.
    return cosine_scores[0].cpu().numpy().astype(np.float64)
"""
def get_hypo_pair_score(para, cands):
# device = torch.device("cpu")
paras = [para for i in range(len(cands))]
# print(len(cands), end=',')
encoded_inputs = tokenizer(
cands, # the first sentence is hypo
paras, # the second sentence is para
add_special_tokens=True,
max_length=256,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt'
)
input_ids = encoded_inputs['input_ids'].to(device)
attention_masks = encoded_inputs['attention_mask'].to(device)
with torch.no_grad():
result = pair_model(input_ids,
token_type_ids=None,
attention_mask=attention_masks,
return_dict=True)
logits = result.logits.detach().cpu()
m = nn.Softmax(dim=1)
prob = m(logits)
indices = torch.tensor([1])
pos_prob = torch.index_select(prob, 1, indices)
return pos_prob.flatten().numpy()
"""
def get_hypo_judge_score(cands):
    """Score the standalone plausibility of each candidate hypothesis.

    Runs the fine-tuned binary classifier ``judge_model`` over the candidates
    and returns, for each, the softmax probability of the positive class
    (label index 1).

    Args:
        cands: List of candidate hypothesis strings.

    Returns:
        1-D numpy array of length ``len(cands)`` with the positive-class
        probability for each candidate.
    """
    encoded_inputs = judge_tokenizer(
        cands,
        add_special_tokens=True,
        max_length=256,        # truncate/pad every candidate to 256 tokens
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded_inputs['input_ids'].to(device)
    attention_masks = encoded_inputs['attention_mask'].to(device)
    # Inference only — no gradients.
    with torch.no_grad():
        result = judge_model(input_ids,
                             token_type_ids=None,
                             attention_mask=attention_masks,
                             return_dict=True)
    logits = result.logits.detach().cpu()
    # Softmax over the two classes, then keep only the positive-class column
    # (idiomatic replacement for nn.Softmax + torch.index_select).
    pos_prob = torch.softmax(logits, dim=1)[:, 1]
    return pos_prob.numpy()