utils.py
import json

import numpy as np
import torch

# Anomaly detection makes NaN/inf gradients easier to localize, but it slows
# training; consider disabling it outside of debugging runs.
torch.autograd.set_detect_anomaly(True)


# pad_tensors and ints_to_tensor are adapted from:
# https://discuss.pytorch.org/t/nested-list-of-variable-length-to-a-tensor/38699/21
def pad_tensors(tensors):
"""
Takes a list of `N` M-dimensional tensors (M<4) and returns a padded tensor.
The padded tensor is `M+1` dimensional with size `N, S1, S2, ..., SM`
where `Si` is the maximum value of dimension `i` amongst all tensors.
"""
rep = tensors[0]
padded_dim = []
for dim in range(rep.dim()):
max_dim = max([tensor.size(dim) for tensor in tensors])
padded_dim.append(max_dim)
padded_dim = [len(tensors)] + padded_dim
padded_tensor = torch.zeros(padded_dim)
padded_tensor = padded_tensor.type_as(rep)
for i, tensor in enumerate(tensors):
size = list(tensor.size())
if len(size) == 1:
padded_tensor[i, :size[0]] = tensor
elif len(size) == 2:
padded_tensor[i, :size[0], :size[1]] = tensor
elif len(size) == 3:
padded_tensor[i, :size[0], :size[1], :size[2]] = tensor
else:
            raise ValueError('Padding is only supported for tensors with up to 3 dimensions.')
return padded_tensor
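
# Usage sketch (illustrative values, not from the original repo): two ragged
# 2-D tensors are zero-padded to the maximum size along each dimension.
# >>> a = torch.ones(2, 3)
# >>> b = torch.ones(1, 5)
# >>> pad_tensors([a, b]).shape
# torch.Size([2, 2, 5])
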
def ints_to_tensor(ints):
"""
Converts a nested list of integers to a padded tensor.
"""
if isinstance(ints, torch.Tensor):
return ints
if isinstance(ints, list):
if isinstance(ints[0], int):
return torch.LongTensor(ints)
if isinstance(ints[0], torch.Tensor):
return pad_tensors(ints)
if isinstance(ints[0], list):
            return ints_to_tensor([ints_to_tensor(inti) for inti in ints])
    # Fail loudly instead of silently returning None on unsupported input.
    raise ValueError('Expected a tensor or a (nested) list of ints/tensors.')
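
# Usage sketch (illustrative): ragged nested lists of ints are converted
# recursively and zero-padded into one rectangular LongTensor.
# >>> ints_to_tensor([[1, 2, 3], [4, 5]])
# tensor([[1, 2, 3],
#         [4, 5, 0]])
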
def get_mask(node_num, max_edu_dist):
    batch_size, max_num = node_num.size(0), node_num.max()
    # True at positions that index a real EDU in each dialogue of the batch.
    mask = torch.arange(max_num, device=node_num.device).unsqueeze(0) < node_num.unsqueeze(1)
    mask = mask.unsqueeze(1).expand(batch_size, max_num, max_num)
    mask = mask & mask.transpose(1, 2)
    # Keep only links that point to a strictly earlier EDU.
    mask = torch.tril(mask, -1)
    # Clip the window of admissible links when dialogues exceed max_edu_dist.
    if max_num > max_edu_dist:
        mask = torch.triu(mask, max_edu_dist - max_num)
    return mask
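
# Usage sketch (illustrative; relies on the device-aware arange above, so it
# also runs on CPU). With one dialogue of 3 EDUs and max_edu_dist=2, only
# links to strictly earlier EDUs inside the distance window remain:
# >>> get_mask(torch.tensor([3]), max_edu_dist=2)[0]
# tensor([[False, False, False],
#         [ True, False, False],
#         [False,  True, False]])
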
def compute_loss(link_scores, label_scores, graphs, mask, p=False, negative=False):
    # Block invalid positions before the softmax over candidate parents.
    link_scores[~mask] = -1e9
    label_mask = (graphs != 0) & mask
    link_mask = label_mask.clone()
    link_scores = torch.nn.functional.softmax(link_scores, dim=-1)
    link_loss = -torch.log(link_scores[link_mask])
    vocab_size = label_scores.size(-1)
    label_loss = torch.nn.functional.cross_entropy(
        label_scores[label_mask].reshape(-1, vocab_size),
        graphs[label_mask].reshape(-1), reduction='none')
    if negative:
        negative_mask = (graphs == 0) & mask
        negative_loss = torch.nn.functional.cross_entropy(
            label_scores[negative_mask].reshape(-1, vocab_size),
            graphs[negative_mask].reshape(-1), reduction='mean')
        return link_loss, label_loss, negative_loss
    if p:
        # Probability assigned to each gold label; rows and gold indices are
        # both selected with label_mask so they stay aligned.
        label_probs = torch.nn.functional.softmax(label_scores[label_mask], dim=-1)
        return link_loss, label_loss, label_probs[
            torch.arange(label_probs.size(0)), graphs[label_mask]]
    return link_loss, label_loss
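
# Shape sketch (hypothetical sizes; the 16 relation types are made up):
# >>> node_num = torch.tensor([3])
# >>> mask = get_mask(node_num, max_edu_dist=10)       # (1, 3, 3) bool
# >>> link_scores = torch.randn(1, 3, 3)               # parent-attachment scores
# >>> label_scores = torch.randn(1, 3, 3, 16)          # relation-type scores
# >>> graphs = torch.zeros(1, 3, 3, dtype=torch.long)  # gold labels, 0 = no edge
# >>> graphs[0, 1, 0] = 2                              # EDU 1 -> EDU 0, relation 2
# >>> link_loss, label_loss = compute_loss(link_scores, label_scores, graphs, mask)
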
def record_eval_result(eval_matrix, predicted_result):
    """Accumulate per-batch predictions into eval_matrix: dict values are
    collected into lists, everything else into a growing numpy array."""
    for k, v in eval_matrix.items():
        if v is None:
            if isinstance(predicted_result[k], dict):
                eval_matrix[k] = [predicted_result[k]]
            else:
                eval_matrix[k] = predicted_result[k]
        elif isinstance(v, list):
            eval_matrix[k] += [predicted_result[k]]
        else:
            eval_matrix[k] = np.append(eval_matrix[k], predicted_result[k])
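
# Usage sketch (illustrative): slots start as None, are filled on the first
# call, and are extended (list append or np.append) on later calls.
# >>> eval_matrix = {'hypothesis': None, 'reference': None, 'edu_num': None}
# >>> record_eval_result(eval_matrix, {'hypothesis': {(0, 1): 3},
# ...                                  'reference': {(0, 1): 3}, 'edu_num': 2})
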
def tsinghua_F1(eval_matrix):
    """Micro-averaged link F1 (bi) and link+relation F1 (multi), following the
    'Tsinghua' evaluation convention, which also scores the implicit
    attachment of the first EDU to the dummy root."""
    cnt_golden, cnt_pred, cnt_cor_bi, cnt_cor_multi = 0, 0, 0, 0
    for hypothesis, reference, edu_num in zip(eval_matrix['hypothesis'],
                                              eval_matrix['reference'],
                                              eval_matrix['edu_num']):
        # cnt[i] = number of incoming gold edges of EDU i; both hypothesis and
        # reference map (parent, child) pairs to relation labels.
        cnt = [0] * edu_num
        for r in reference:
            cnt[r[1]] += 1
        # Every EDU without an incoming gold edge is attached to the dummy root.
        for i in range(edu_num):
            if cnt[i] == 0:
                cnt_golden += 1
        # The hypothesis is assumed to attach exactly one EDU (the first)
        # to the dummy root; it is correct if the gold does the same.
        cnt_pred += 1
        if cnt[0] == 0:
            cnt_cor_bi += 1
            cnt_cor_multi += 1
        cnt_golden += len(reference)
        cnt_pred += len(hypothesis)
        for pair in hypothesis:
            if pair in reference:
                cnt_cor_bi += 1
                if hypothesis[pair] == reference[pair]:
                    cnt_cor_multi += 1
    prec_bi, recall_bi = cnt_cor_bi * 1. / cnt_pred, cnt_cor_bi * 1. / cnt_golden
    f1_bi = 2 * prec_bi * recall_bi / (prec_bi + recall_bi)
    prec_multi, recall_multi = cnt_cor_multi * 1. / cnt_pred, cnt_cor_multi * 1. / cnt_golden
    f1_multi = 2 * prec_multi * recall_multi / (prec_multi + recall_multi)
    return f1_bi, f1_multi
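
# Usage sketch (toy example; the relation id 3 is made up). A perfect
# prediction on a two-EDU dialogue scores 1.0 on both metrics:
# >>> eval_matrix = {'hypothesis': [{(0, 1): 3}], 'reference': [{(0, 1): 3}],
# ...                'edu_num': [2]}
# >>> tsinghua_F1(eval_matrix)
# (1.0, 1.0)
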
def conv_list2Dic(datalist):
    """Convert a list of records into a dict keyed by each record's 'id'."""
    dataDic = {}
    for da in datalist:
        dataDic[da['id']] = da
    return dataDic
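
# Usage sketch (illustrative):
# >>> conv_list2Dic([{'id': 'a', 'x': 1}, {'id': 'b', 'x': 2}])
# {'a': {'id': 'a', 'x': 1}, 'b': {'id': 'b', 'x': 2}}
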
def write_selected_data(id_file, src_file, des_file):
    """Select the records whose ids are listed in id_file from src_file
    (a JSON list) and write them to des_file."""
    with open(id_file, 'r', encoding='utf8') as fr:
        lines = fr.readlines()
    # De-duplicate ids; note that set() does not preserve the input order.
    ids_list = list(set(a.strip() for a in lines))
    with open(src_file, 'r', encoding='utf8') as fr:
        src_datas = json.load(fr)
    src_dic = conv_list2Dic(src_datas)
    des_datas = []
    for data_id in ids_list:
        des_datas.append(src_dic[data_id])
    with open(des_file, 'w', encoding='utf8') as fw:
        json.dump(des_datas, fw, ensure_ascii=False)
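
# Usage sketch (hypothetical file names): keep only the dialogues whose ids
# appear in selected_ids.txt (one id per line) from all_data.json.
# write_selected_data('selected_ids.txt', 'all_data.json', 'selected.json')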