-
Notifications
You must be signed in to change notification settings - Fork 21
/
Fastformer.py
53 lines (46 loc) · 2.21 KB
/
Fastformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import einops
from einops import rearrange
import torch
import torch.nn as nn
class Fastformer(nn.Module):
    """Additive-attention layer that mixes tokens through pooled global vectors.

    The input is projected to query/key/value tensors of width ``decode_dim``.
    A global query is pooled from the queries, multiplied into every key, and a
    global key is pooled from that product. The global key then gates the
    values element-wise; the result goes through a final linear map and is
    added to the query projection.

    Args:
        dim: Size of the last axis of the input ``x``.
        decode_dim: Width of the query/key/value projections and of the output.

    NOTE(review): both softmaxes normalise over the last (``decode_dim``) axis
    and the mask is broadcast along that same axis. If per-token weighting or
    masking over the sequence axis was intended, this needs a deliberate
    change -- TODO confirm against the intended design.
    """

    def __init__(self, dim=3, decode_dim=16):
        super(Fastformer, self).__init__()
        # Not used by forward(); kept so parameter creation order and
        # state_dict keys stay exactly as they were.
        self.to_qkv = nn.Linear(dim, decode_dim * 3, bias=False)
        # Bias-free projections for query, key and value.
        self.weight_q = nn.Linear(dim, decode_dim, bias=False)
        self.weight_k = nn.Linear(dim, decode_dim, bias=False)
        self.weight_v = nn.Linear(dim, decode_dim, bias=False)
        # Output transform applied to the key/value interaction.
        self.weight_r = nn.Linear(decode_dim, decode_dim, bias=False)
        # Learned per-feature scaling vectors for the two attention logits.
        self.weight_alpha = nn.Parameter(torch.randn(decode_dim))
        self.weight_beta = nn.Parameter(torch.randn(decode_dim))
        self.scale_factor = decode_dim ** -0.5

    @staticmethod
    def _masked_softmax(logits, keep):
        """Softmax over the last axis, first suppressing positions where ``keep`` is False."""
        if keep is not None:
            # Use the logits' own dtype so the fill value is always representable.
            fill_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(~keep, fill_value)
        return torch.softmax(logits, dim=-1)

    def forward(self, x, mask=None):
        """Apply the attention layer.

        Args:
            x: Tensor whose last axis has size ``dim``; the code unpacks the
                projected tensor as three axes ``(b, n, decode_dim)``.
            mask: Optional 2-D tensor. It is given a singleton middle axis and
                broadcast against the ``(b, n, decode_dim)`` logits, so its
                last axis must be ``decode_dim`` (or 1). True / non-zero
                entries are kept; the rest are excluded from both softmaxes.
                ``None`` (the default) disables masking.

        Returns:
            Tensor of shape ``(b, n, decode_dim)``.

        Raises:
            ValueError: If ``mask`` is given but is not 2-D.
        """
        query = self.weight_q(x)
        key = self.weight_k(x)
        value = self.weight_v(x)

        keep = None
        if mask is not None:
            if mask.dim() != 2:
                raise ValueError(
                    "mask must be a 2-D tensor, got %d dimension(s)" % mask.dim()
                )
            # (b, m) -> (b, 1, m); coerce so 0/1 integer masks work with `~`.
            keep = mask.to(torch.bool).unsqueeze(1)

        # Calculate the global query: weight the queries, then sum over axis 1.
        alpha_weight = self._masked_softmax(
            query * self.weight_alpha * self.scale_factor, keep
        )
        global_query = (query * alpha_weight).sum(dim=1)

        # Model the interaction between the global query vector and every key.
        # Broadcasting replaces an explicit repeat along axis 1.
        p = global_query.unsqueeze(1) * key
        beta_weight = self._masked_softmax(
            p * self.weight_beta * self.scale_factor, keep
        )
        global_key = (p * beta_weight).sum(dim=1)

        # Key-value interaction: the global key gates each value element-wise.
        key_value_interaction = global_key.unsqueeze(1) * value
        key_value_interaction_out = self.weight_r(key_value_interaction)

        # Residual connection with the query projection.
        return key_value_interaction_out + query
if __name__ == '__main__':
    # Smoke test: push one random batch through the layer and print the
    # output for the first sample.
    net = Fastformer(dim=3, decode_dim=8)
    inputs = torch.randn(4, 6, 3)
    # All-True mask, so nothing is masked out. Its last axis (8) equals
    # decode_dim, which is the axis the mask is broadcast along in forward().
    keep_all = torch.ones(1, 8).to(torch.bool)
    output = net(inputs, keep_all)
    print(output[0])