-
Notifications
You must be signed in to change notification settings - Fork 4
/
AKT.py
249 lines (211 loc) · 10.6 KB
/
AKT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from models.BaseModel import BaseModel
from utils import utils
class AKT(BaseModel):
    """Sequence model that predicts per-step correctness from skill/question history.

    Skills and (skill, response) interactions are embedded, modulated by a
    scalar per-question difficulty, and passed through two stacks of
    ``TransformerLayer`` blocks: ``blocks_1`` encodes the interaction
    sequence, ``blocks_2`` alternates self-attention over the skill sequence
    with attention that reads the encoded interactions.
    """

    extra_log_args = ['num_layer', 'num_head']

    @staticmethod
    def parse_model_args(parser, model_name='AKT'):
        """Register the AKT-specific options and defer the rest to ``BaseModel``.

        Args:
            parser: argument parser exposing ``add_argument``.
            model_name: name forwarded to ``BaseModel.parse_model_args``.

        Returns:
            Whatever ``BaseModel.parse_model_args`` returns.
        """
        parser.add_argument('--emb_size', type=int, default=64,
                            help='Size of embedding vectors.')
        parser.add_argument('--num_layer', type=int, default=1,
                            help='Self-attention layers.')
        parser.add_argument('--num_head', type=int, default=4,
                            help='Self-attention heads.')
        return BaseModel.parse_model_args(parser, model_name)

    def __init__(self, args, corpus):
        """Build embeddings, attention stacks and the prediction head.

        Args:
            args: namespace providing ``model_path``, ``emb_size``,
                ``num_head``, ``num_layer``, ``dropout`` and ``gpu``.
            corpus: object providing ``n_skills`` and ``n_problems``.

        Raises:
            ValueError: if ``emb_size`` is not a positive multiple of
                ``num_head``.
        """
        super().__init__(model_path=args.model_path)
        self.skill_num = int(corpus.n_skills)
        self.question_num = int(corpus.n_problems)
        self.emb_size = args.emb_size
        self.num_head = args.num_head
        self.dropout = args.dropout

        # Each head gets emb_size // num_head features; a remainder would be
        # silently dropped here and only surface later as a reshape failure
        # (or a wrong reshape) inside the attention module.
        if self.num_head <= 0 or self.emb_size % self.num_head != 0:
            raise ValueError(
                'emb_size ({}) must be a positive multiple of num_head ({})'.format(
                    self.emb_size, self.num_head))
        head_size = self.emb_size // self.num_head

        # NOTE: modules are created in the same order as before so that
        # parameter names and seeded initialisation stay unchanged.
        self.skill_embeddings = nn.Embedding(self.skill_num, self.emb_size)
        # Interaction index = skill + label * skill_num, hence 2 * skill_num rows.
        self.inter_embeddings = nn.Embedding(self.skill_num * 2, self.emb_size)
        self.difficult_param = nn.Embedding(self.question_num, 1)
        self.skill_diff = nn.Embedding(self.skill_num, self.emb_size)
        self.inter_diff = nn.Embedding(self.skill_num * 2, self.emb_size)

        self.blocks_1 = nn.ModuleList([
            TransformerLayer(d_model=self.emb_size, d_feature=head_size, d_ff=self.emb_size,
                             dropout=self.dropout, n_heads=self.num_head, kq_same=False, gpu=args.gpu)
            for _ in range(args.num_layer)
        ])
        # Twice as many layers: forward() consumes them in (self-attn, cross-attn) pairs.
        self.blocks_2 = nn.ModuleList([
            TransformerLayer(d_model=self.emb_size, d_feature=head_size, d_ff=self.emb_size,
                             dropout=self.dropout, n_heads=self.num_head, kq_same=False, gpu=args.gpu)
            for _ in range(args.num_layer * 2)
        ])

        self.out = nn.Sequential(
            nn.Linear(self.emb_size * 2, 64), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(32, 1)
        )

        self.loss_function = nn.BCELoss(reduction='sum')

    def forward(self, feed_dict):
        """Compute correctness probabilities for every step after the first.

        Args:
            feed_dict: dict with ``'skill_seq'``, ``'quest_seq'`` and
                ``'label_seq'`` tensors, each ``[batch_size, real_max_step]``.

        Returns:
            dict with ``'prediction'`` (sigmoid outputs) and ``'label'``
            (labels cast to the prediction dtype), both with the first step
            removed.
        """
        skills = feed_dict['skill_seq']      # [batch_size, real_max_step]
        questions = feed_dict['quest_seq']   # [batch_size, real_max_step]
        labels = feed_dict['label_seq']      # [batch_size, real_max_step]

        # Negative labels (get_feed_dict pads with -1) are mapped to 0 so the
        # interaction index below stays inside the embedding table.
        mask_labels = labels * (labels > -1).long()
        inters = skills + mask_labels * self.skill_num

        skill_data = self.skill_embeddings(skills)
        inter_data = self.inter_embeddings(inters)

        skill_diff_data = self.skill_diff(skills)
        inter_diff_data = self.inter_diff(inters)

        # Scalar per-question difficulty scales the variation vectors.
        q_diff = self.difficult_param(questions)
        skill_data = skill_data + q_diff * skill_diff_data
        inter_data = inter_data + q_diff * inter_diff_data

        x, y = skill_data, inter_data
        for block in self.blocks_1:  # encode
            y = block(mask=1, query=y, key=y, values=y)
        for index, block in enumerate(self.blocks_2):
            if index % 2 == 0:  # peek current question
                x = block(mask=1, query=x, key=x, values=x, apply_pos=False)
            else:  # don't peek current response
                x = block(mask=0, query=x, key=x, values=y, apply_pos=True)

        concat_q = torch.cat([x, skill_data], dim=-1)
        prediction = self.out(concat_q).squeeze(-1).sigmoid()

        # Follow the model's own precision instead of forcing float64, so the
        # label dtype always matches what BCELoss receives as input.
        out_dict = {
            'prediction': prediction[:, 1:],
            'label': labels[:, 1:].to(prediction.dtype),
        }
        return out_dict

    def loss(self, feed_dict, outdict):
        """Return the summed binary cross-entropy over non-padded steps.

        Args:
            feed_dict: unused here; kept for the shared model interface.
            outdict: dict produced by ``forward``.

        Returns:
            Scalar loss tensor.
        """
        prediction = outdict['prediction'].flatten()
        label = outdict['label'].flatten()
        mask = label > -1  # drop steps whose label is the -1 padding value
        return self.loss_function(prediction[mask], label[mask])

    def get_feed_dict(self, corpus, data, batch_start, batch_size, phase):
        """Slice one batch out of ``data`` and pad it into tensors.

        Args:
            corpus: unused here; kept for the shared model interface.
            data: table with ``'skill_seq'``, ``'problem_seq'`` and
                ``'correct_seq'`` columns (supports ``len`` and slicing with
                ``.values``).
            batch_start: index of the first row of the batch.
            batch_size: maximum number of rows in the batch.
            phase: unused here; kept for the shared model interface.

        Returns:
            dict with ``'skill_seq'``, ``'quest_seq'`` and ``'label_seq'``
            tensors; labels are padded with -1.
        """
        batch_end = min(len(data), batch_start + batch_size)
        skill_seqs = data['skill_seq'][batch_start:batch_end].values
        quest_seqs = data['problem_seq'][batch_start:batch_end].values
        label_seqs = data['correct_seq'][batch_start:batch_end].values

        feed_dict = {
            'skill_seq': torch.from_numpy(utils.pad_lst(skill_seqs)),            # [batch_size, real_max_step]
            'quest_seq': torch.from_numpy(utils.pad_lst(quest_seqs)),            # [batch_size, real_max_step]
            'label_seq': torch.from_numpy(utils.pad_lst(label_seqs, value=-1)),  # [batch_size, real_max_step]
        }
        return feed_dict
class TransformerLayer(nn.Module):
    """Basic Transformer block.

    Contains one multi-head attention module followed by a residual
    connection with layer norm, and optionally a position-wise feed-forward
    net with its own dropout, residual connection and layer norm.
    """

    def __init__(self, d_model, d_feature, d_ff, n_heads, dropout, kq_same, gpu=''):
        """Create the attention module and the feed-forward sub-layers.

        Args:
            d_model: width of the input and output features.
            d_feature: per-head feature size passed to the attention module.
            d_ff: hidden width of the feed-forward net.
            n_heads: number of attention heads.
            dropout: dropout probability used by every dropout layer.
            kq_same: forwarded to ``MultiHeadAttention``.
            gpu: device flag forwarded to ``MultiHeadAttention``; stored for
                compatibility but no longer used for tensor placement here.
        """
        super().__init__()
        self.gpu = gpu
        # Multi-head attention block
        self.masked_attn_head = MultiHeadAttention(
            d_model, d_feature, n_heads, dropout, kq_same=kq_same, gpu=gpu)

        # Two layer-norm layers and their dropout layers
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, mask, query, key, values, apply_pos=True):
        """Apply masked attention and, optionally, the feed-forward net.

        Args:
            mask: integer diagonal offset of the blocked upper triangle.
                ``1`` lets a position attend to itself and earlier positions;
                ``0`` restricts it to strictly earlier positions and zeroes
                the attention of the first position.
            query: tensor indexed as ``[batch, seq_len, ...]``.
            key: tensor passed to the attention module as keys.
            values: tensor passed to the attention module as values.
            apply_pos: when true, run the feed-forward sub-layer as well.

        Returns:
            The transformed ``query`` tensor.
        """
        seqlen = query.size(1)
        # Build the mask directly on the device the data lives on instead of
        # going through NumPy and an unconditional .cuda() transfer.
        blocked = torch.triu(
            torch.ones(seqlen, seqlen, device=query.device), diagonal=mask)
        src_mask = (blocked == 0)[None, None, :, :]  # 1, 1, seqlen, seqlen; True = may attend

        # With mask == 0 the first row has nothing to attend to, so the
        # attention module is asked to zero it out.
        query2 = self.masked_attn_head(
            query, key, values, mask=src_mask, zero_pad=(mask == 0))

        query = query + self.dropout1(query2)
        query = self.layer_norm1(query)
        if apply_pos:
            query2 = self.linear2(self.dropout(
                self.activation(self.linear1(query))))
            query = query + self.dropout2(query2)
            query = self.layer_norm2(query)
        return query
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a learned, per-head distance decay.

    It has projection layers for keys, queries and values, followed by
    attention and a final linear layer. Attention logits are multiplied by
    ``exp(-softplus(gamma) * distance)`` before the softmax, where
    ``distance`` is derived from the attention mass and the position gap
    between query and key.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout, kq_same, bias=True, gpu=''):
        """Create the projections and the per-head decay parameters.

        Args:
            d_model: width of the input and output features.
            d_feature: per-head feature size; ``n_heads * d_feature`` is
                expected to equal ``d_model`` for the reshapes in ``forward``.
            n_heads: number of attention heads.
            dropout: dropout probability applied to the attention weights.
            kq_same: when exactly ``False`` a separate query projection is
                created; otherwise queries reuse the key projection.
            bias: whether the linear layers carry a bias term.
            gpu: device flag; stored for compatibility, tensor placement now
                follows the inputs.
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_feature
        self.h = n_heads
        self.kq_same = kq_same
        self.gpu = gpu

        self.v_linear = nn.Linear(d_model, d_model, bias=bias)
        self.k_linear = nn.Linear(d_model, d_model, bias=bias)
        if kq_same is False:
            self.q_linear = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.proj_bias = bias
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.gammas = nn.Parameter(torch.zeros(n_heads, 1, 1))
        torch.nn.init.xavier_uniform_(self.gammas)

    def forward(self, q, k, v, mask, zero_pad):
        """Project the inputs, run distance-aware attention and merge heads.

        Args:
            q: query tensor indexed as ``[batch, seq_len, d_model]``.
            k: key tensor with the same layout.
            v: value tensor with the same layout.
            mask: tensor broadcastable to ``[batch, heads, seq_len, seq_len]``;
                zero/False entries are blocked.
            zero_pad: when true, the attention of the first query position
                is replaced by zeros.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.
        """
        bs = q.size(0)

        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        if self.kq_same is False:
            q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        else:
            q = self.k_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # calculate attention
        scores = self.attention(q, k, v, self.d_k, mask, self.dropout, zero_pad, self.gammas)

        # concatenate heads and put through final linear layer
        concat = scores.transpose(1, 2).reshape(bs, -1, self.d_model)
        return self.out_proj(concat)

    def attention(self, q, k, v, d_k, mask, dropout, zero_pad, gamma=None):
        """Compute distance-decayed scaled dot-product attention.

        Args:
            q: queries, ``[batch, heads, seq_len, d_k]``.
            k: keys, same layout.
            v: values, same layout.
            d_k: per-head feature size used for scaling.
            mask: tensor broadcastable to the score matrix; zero/False
                entries are blocked.
            dropout: callable applied to the attention weights.
            zero_pad: when true, the first query row of the attention
                weights is replaced by zeros.
            gamma: per-head decay parameters of shape ``[heads, 1, 1]``;
                defaults to ``self.gammas``.

        Returns:
            Attention output, ``[batch, heads, seq_len, d_k]``.
        """
        if gamma is None:
            gamma = self.gammas

        scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5  # BS, head, seqlen, seqlen
        bs, head, seqlen = scores.size(0), scores.size(1), scores.size(2)

        # Every helper tensor follows the device and precision of the scores
        # instead of a hard-coded .cuda()/.double(), so float32 and CPU runs
        # work while float64 results are unchanged.
        device, dtype = scores.device, scores.dtype
        mask = mask.to(device)

        positions = torch.arange(seqlen, device=device)
        # |key index - query index|, shaped 1, 1, seqlen, seqlen
        position_effect = (positions[None, :] - positions[:, None]).abs()[None, None, :, :].to(dtype)

        with torch.no_grad():
            scores_ = F.softmax(scores, dim=-1)  # BS, head, seqlen, seqlen
            scores_ = scores_ * mask.to(dtype)
            distcum_scores = torch.cumsum(scores_, dim=-1)  # bs, head, sl, sl
            disttotal_scores = torch.sum(scores_, dim=-1, keepdim=True)  # bs, head, sl, 1
            # bs, head, sl, sl positive distance
            dist_scores = torch.clamp(
                (disttotal_scores - distcum_scores) * position_effect, min=0.)
            dist_scores = dist_scores.sqrt().detach()

        gamma = -1. * F.softplus(gamma).unsqueeze(0)  # 1, head, 1, 1; always negative
        # exp(gamma * distance), clamped to [1e-5, 1e5]
        total_effect = torch.clamp((dist_scores * gamma).exp(), min=1e-5, max=1e5)
        scores = scores * total_effect

        scores = scores.masked_fill(mask == 0, float('-inf'))
        scores = F.softmax(scores, dim=-1)  # BS, head, seqlen, seqlen
        if zero_pad:
            # Replace the first query row, which may be fully masked (NaN after softmax).
            pad_zero = torch.zeros(bs, head, 1, seqlen, dtype=dtype, device=device)
            scores = torch.cat([pad_zero, scores[:, :, 1:, :]], dim=2)
        scores = dropout(scores)
        return torch.matmul(scores, v)