vit_transformers.py

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def split_last(x, shape):
    """Split the last dimension into the given shape."""
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    """Merge the last n_dims dimensions into a single dimension."""
    s = x.size()
    assert 1 < n_dims < len(s)
    return x.view(*s[:-n_dims], -1)
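

# Illustrative example (assumed sizes, not from this repository): with x of shape
# (B, S, 768) and 12 heads,
#   split_last(x, (12, -1)) -> (B, S, 12, 64)
#   merge_last(y, 2)        -> (B, S, 768) for y of shape (B, S, 12, 64)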


class MultiHeadedSelfAttention(nn.Module):
    """Multi-Headed Dot Product Attention"""
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # for visualization

    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(t, (self.n_heads, -1)).transpose(1, 2) for t in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h
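

# Visualization note (assumed usage): after a forward pass, the per-head attention maps
# remain available on the module as `scores`, shaped (B, H, S, S), e.g. for plotting.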


class PositionWiseFeedForward(nn.Module):
    """FeedForward Neural Networks for each position"""
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(self.gelu(self.fc1(x)))

    # nn.GELU is only available in PyTorch >= 1.7.0, but this model targets 1.0.0,
    # so define it here using the standard tanh approximation of GELU.
    def gelu(self, x):
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Block(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Pre-norm residuals: LayerNorm before each sub-layer, then add the residual.
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x


class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x
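

# Minimal usage sketch (not part of the original file): the hyperparameters below are
# illustrative ViT-Base-like assumptions, not values taken from this repository.
if __name__ == "__main__":
    model = Transformer(num_layers=12, dim=768, num_heads=12, ff_dim=3072, dropout=0.1)
    tokens = torch.randn(2, 197, 768)  # (B, S, D), e.g. 196 patch tokens + 1 class token
    out = model(tokens)                # mask defaults to None
    print(out.shape)                   # expected: torch.Size([2, 197, 768])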