forked from dome272/VQGAN-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mingpt.py
171 lines (137 loc) · 6.91 KB
/
mingpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
taken from: https://github.com/karpathy/minGPT/
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k, v in kwargs.items():
setattr(self, k, v)
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here but I am including an
explicit implementation here to show that there is nothing too scary here.
"""
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads
self.key = nn.Linear(config.n_embd, config.n_embd)
self.query = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
# regularization
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
mask = torch.tril(torch.ones(config.block_size,
config.block_size))
if hasattr(config, "n_unmasked"):
mask[:config.n_unmasked, :config.n_unmasked] = 1
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, layer_past=None):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
present = torch.stack((k, v))
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
if layer_past is None:
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y, present # TODO: check that this does not break anything
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present:
assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
super().__init__()
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
n_unmasked=n_unmasked)
# input embedding stem
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) # 512 x 1024
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
self.apply(self._init_weights)
self.config = config
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, idx, embeddings=None):
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
return logits, None