-
Notifications
You must be signed in to change notification settings - Fork 2
/
llama2_model.py
462 lines (373 loc) · 16.2 KB
/
llama2_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
@dataclass
class ModelArgs:
    """Hyper-parameters for a Llama-2 style decoder-only transformer.

    Fields are declared in a fixed order, so positional construction
    (``ModelArgs(dim, n_layers, ...)``) relies on that order.
    """

    # Width of the token embeddings / residual stream.
    dim: int = 4096
    # Number of stacked TransformerBlocks.
    n_layers: int = 32
    # Number of query heads in each attention layer.
    n_heads: int = 32
    # Number of key/value heads; None means "use n_heads" (see Attention.__init__).
    n_kv_heads: Optional[int] = None
    # Placeholder value; expected to be filled in from the tokenizer.
    vocab_size: int = -1
    # FeedForward's hidden size is rounded up to a multiple of this value.
    multiple_of: int = 256
    # Optional extra scale applied to FeedForward's hidden size.
    ffn_dim_multiplier: Optional[float] = None
    # Epsilon passed to every RMSNorm in the model.
    norm_eps: float = 1e-5
    # NOTE(review): not read by any code visible in this module.
    max_batch_size: int = 32
    # The rotary table is precomputed for twice this many positions.
    max_seq_len: int = 32768
    # True: each block scales its init std by its own depth (layer_id + 1).
    # False: every block scales by the total number of blocks.
    depth_init: bool = True
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the rotary-embedding table of unit complex numbers.

    For each position ``t`` in ``[0, end)`` and each frequency index ``k`` in
    ``[0, dim // 2)``, entry ``[t, k]`` is the unit-magnitude complex number
    with angle ``t * theta ** (-2k / dim)``.

    Args:
        dim (int): Per-head embedding size; the table has ``dim // 2`` columns.
        end (int): Number of positions (rows) to precompute.
        theta (float, optional): Base of the geometric frequency progression.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape ``(end, dim // 2)``.
    """
    # Even indices 0, 2, 4, ... trimmed to exactly dim // 2 entries (matters for odd dim).
    even_indices = torch.arange(0, dim, 2)[: (dim // 2)]
    exponents = even_indices.float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore

    # Magnitude 1 and phase `angles`, i.e. cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key vectors by the precomputed rotary frequencies.

    Consecutive pairs along the last dimension are treated as (real, imag) parts
    of complex numbers and multiplied by ``freqs_cis``, which is broadcast over the
    remaining dimensions. The rotation is done in float32, and each result is cast
    back to the dtype of its input.

    Args:
        xq (torch.Tensor): Query tensor; as called from Attention.forward it is
            ``(bsz, seqlen, n_heads, head_dim)``.
        xk (torch.Tensor): Key tensor, laid out like ``xq`` but possibly with fewer heads.
        freqs_cis (torch.Tensor): Complex table of shape ``(seqlen, head_dim // 2)``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated queries and keys, same shapes as inputs.
    """

    def _pairs_as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., head_dim) -> (..., head_dim // 2) complex, pairing adjacent elements.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _pairs_as_complex(xq)
    k_complex = _pairs_as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    # view_as_real appends a size-2 axis; flatten(3) folds it back into head_dim.
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along dimension 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``, but built
    from ``expand`` + ``reshape``. When ``n_rep == 1`` the input object itself is
    returned.

    Args:
        x (torch.Tensor): 4-D tensor ``(bs, slen, n_kv_heads, head_dim)``.
        n_rep (int): Number of copies of each head.

    Returns:
        torch.Tensor: ``(bs, slen, n_kv_heads * n_rep, head_dim)``.
    """
    batch, seq_len, kv_heads, head_dim = x.shape
    if n_rep != 1:
        widened = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
        return widened.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
    return x
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization with a learnable per-channel gain.

    Computes ``x / sqrt(mean(x ** 2 over the last dim) + eps)`` in float32, casts
    the result back to the input dtype, then multiplies by ``weight``. The final
    dtype follows PyTorch type promotion between the two operands.

    Args:
        dim (int): Size of the last (normalized) dimension.
        eps (float, optional): Added inside the square root for numerical
            stability. Default is 1e-6.

    Attributes:
        eps (float): Stability constant.
        weight (nn.Parameter): Gain of shape ``(dim,)``, initialized to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

    def reset_parameters(self):
        # Restore the gain to its initial all-ones value.
        torch.nn.init.ones_(self.weight)  # type: ignore
class Attention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings and optional
    grouped-query attention (fewer key/value heads than query heads).

    Args:
        model_args (ModelArgs): Model configuration arguments.

    Attributes:
        n_heads (int): Number of query heads.
        n_kv_heads (int): Number of key and value heads (``n_heads`` when
            ``model_args.n_kv_heads`` is None).
        n_rep (int): How many times each key/value head is repeated so the
            keys/values match the number of query heads.
        head_dim (int): Dimension size of each attention head.
        wq (Linear): Linear transformation for queries.
        wk (Linear): Linear transformation for keys.
        wv (Linear): Linear transformation for values.
        wo (Linear): Linear transformation for output.

    Raises:
        ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.n_kv_heads = (
            model_args.n_heads
            if model_args.n_kv_heads is None
            else model_args.n_kv_heads
        )
        # Each kv head must serve a whole number of query heads. Otherwise
        # n_rep would silently floor, repeat_kv would yield a head count that
        # does not match the queries, and forward would fail much later with
        # an opaque shape error.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = model_args.dim // model_args.n_heads
        self.wq = nn.Linear(
            model_args.dim, model_args.n_heads * self.head_dim, bias=False
        )
        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(
            model_args.n_heads * self.head_dim, model_args.dim, bias=False
        )

    def init_weights(self, init_std: float):
        """Truncated-normal init: std 0.02 for wq/wk/wv, ``init_std`` for wo."""
        for linear in (self.wq, self.wk, self.wv):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input of shape ``(bsz, seqlen, dim)``.
            freqs_cis (torch.Tensor): Rotary table of shape
                ``(seqlen, head_dim // 2)``.

        Returns:
            torch.Tensor: Output of shape ``(bsz, seqlen, dim)``.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        # Expand kv heads so they line up one-to-one with the query heads.
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_heads, head_dim)
        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xk = keys.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xv = values.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        # Causal mask: each position attends only to itself and earlier positions.
        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        output = output.transpose(
            1, 2
        ).contiguous()  # (bs, seqlen, n_heads, head_dim)
        output = output.view(bsz, seqlen, -1)
        return self.wo(output)
class FeedForward(nn.Module):
    """
    SwiGLU feed-forward network computing ``w2(silu(w1(x)) * w3(x))``.

    The requested ``hidden_dim`` is processed in three steps:
    1. It is scaled by 2/3 (truncated to int).
    2. It is optionally multiplied by ``ffn_dim_multiplier`` (truncated again).
    3. It is rounded *up* to the nearest multiple of ``multiple_of``.

    Args:
        dim (int): Input and output dimension.
        hidden_dim (int): Requested hidden size before the adjustments above.
        multiple_of (int): The final hidden size is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Extra scale on the hidden size, or None.

    Attributes:
        w1 (Linear): Gate projection, dim -> hidden.
        w2 (Linear): Output projection, hidden -> dim.
        w3 (Linear): Value projection, dim -> hidden.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Ceiling division, then scale back up: smallest multiple of multiple_of >= width.
        width = -(-width // multiple_of) * multiple_of
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

    def init_weights(self, init_std: float):
        # Truncated-normal init, applied in the order w1, w2, w3.
        schedule = ((self.w1, 0.02), (self.w2, init_std), (self.w3, init_std))
        for layer, std in schedule:
            nn.init.trunc_normal_(layer.weight, mean=0.0, std=std)
class TransformerBlock(nn.Module):
    """
    One pre-normalization decoder block: self-attention followed by a SwiGLU
    feed-forward network, each wrapped in a residual connection::

        h   = x + attention(attention_norm(x))
        out = h + feed_forward(ffn_norm(h))

    Args:
        layer_id (int): Index of this block in the stack (the Transformer
            assigns 0 .. n_layers - 1).
        model_args (ModelArgs): Model configuration arguments.

    Attributes:
        n_heads (int): Number of attention (query) heads.
        dim (int): Model width.
        attention (Attention): Self-attention sub-layer.
        feed_forward (FeedForward): Feed-forward sub-layer (requested hidden size 4 * dim).
        layer_id (int): Index of this block.
        num_layers (int): Total number of blocks in the model.
        attention_norm (RMSNorm): Normalization applied to the attention *input*.
        ffn_norm (RMSNorm): Normalization applied to the feed-forward *input*.
        weight_init_std (float): Std passed to the sub-layers' ``init_weights``.
    """

    def __init__(self, layer_id: int, model_args: ModelArgs):
        super().__init__()
        width = model_args.dim
        self.n_heads = model_args.n_heads
        self.dim = width
        self.attention = Attention(model_args)
        self.feed_forward = FeedForward(
            dim=width,
            hidden_dim=4 * width,
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.num_layers = model_args.n_layers
        self.attention_norm = RMSNorm(dim=width, eps=model_args.norm_eps)
        self.ffn_norm = RMSNorm(dim=width, eps=model_args.norm_eps)
        # Shrink the init std by depth: this block's own depth when depth_init
        # is set, otherwise the total number of blocks.
        depth = self.layer_id + 1 if model_args.depth_init else self.num_layers
        self.weight_init_std = 0.02 / (2 * depth) ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Run attention and feed-forward, each with a residual connection.

        Args:
            x (torch.Tensor): Hidden states of shape ``(bsz, seqlen, dim)``.
            freqs_cis (torch.Tensor): Rotary table forwarded to the attention layer.

        Returns:
            torch.Tensor: Updated hidden states, same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), freqs_cis)
        h = x + attn_out
        return h + self.feed_forward(self.ffn_norm(h))

    def init_weights(self):
        # Norms first, then attention, then feed-forward (fixed RNG consumption order).
        self.attention_norm.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        self.feed_forward.init_weights(self.weight_init_std)
class Transformer(nn.Module):
    """
    Llama-2 style decoder-only transformer.

    Pipeline: token embeddings -> ``n_layers`` TransformerBlocks -> RMSNorm ->
    linear projection to vocabulary logits.

    Args:
        model_args (ModelArgs): Model configuration arguments.

    Attributes:
        model_args (ModelArgs): Model configuration arguments.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of layers in the model.
        model_dim (int): Model width (``model_args.dim``).
        tok_embeddings (nn.Embedding): Token embeddings.
        layers (torch.nn.ModuleList): List of Transformer blocks.
        norm (RMSNorm): Normalization applied to the final hidden states.
        output (nn.Linear): Projection from hidden states to vocabulary logits.
        freqs_cis (torch.Tensor): Registered buffer holding the complex rotary
            table for ``2 * max_seq_len`` positions.
    """

    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers
        self.model_dim = model_args.dim
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.register_buffer("freqs_cis", self._precompute_freqs_cis())
        self.layers = torch.nn.ModuleList()
        for layer_id in range(model_args.n_layers):
            self.layers.append(TransformerBlock(layer_id, model_args))
        self.norm = RMSNorm(
            dim=model_args.dim, eps=model_args.norm_eps
        )
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.init_weights()

    def _precompute_freqs_cis(self) -> torch.Tensor:
        """Build the rotary table for this model's head size and position budget."""
        return precompute_freqs_cis(
            self.model_args.dim // self.model_args.n_heads,
            # Need to compute until at least the max token limit for generation
            # (use 2x max sequence length to be safe)
            self.model_args.max_seq_len * 2,
        )

    def init_weights(self):
        """
        [Note: On ``init_weights`` vs. ``reset_parameters``]
        Modules may define ``reset_parameters`` to initialize parameter values.
        ``reset_parameters`` is meant to only initialize directly owned
        parameters/buffers, not those of their child modules, and it can be
        used to give the initial values for these tensors.
        Separately, users may want custom initialization for their modules,
        different from that in ``reset_parameters``. For this, we define
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.
        """
        # Recompute the rotary table on the device the buffer currently lives
        # on (presumably so a model created elsewhere can be re-initialized in place).
        with torch.device(self.freqs_cis.device):
            self.freqs_cis = self._precompute_freqs_cis()
        nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers:
            layer.init_weights()
        self.norm.reset_parameters()
        # Output projection: std = dim ** -0.5, truncated at +/- 3 std.
        final_out_std = self.model_args.dim**-0.5
        cutoff_factor = 3
        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=final_out_std,
            a=-cutoff_factor * final_out_std,
            b=cutoff_factor * final_out_std,
        )

    def forward(self, tokens: torch.Tensor):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Token indices of shape ``(bsz, seqlen)``.

        Returns:
            torch.Tensor: float32 logits of shape ``(bsz, seqlen, vocab_size)``.

        Raises:
            ValueError: If ``seqlen`` exceeds the number of precomputed rotary
                positions (``2 * max_seq_len``).
        """
        _bsz, seqlen = tokens.shape
        max_positions = self.freqs_cis.shape[0]
        # Without this check the slice below would silently return fewer rows
        # and the failure would surface as a bare assert deep inside attention.
        if seqlen > max_positions:
            raise ValueError(
                f"Sequence length {seqlen} exceeds the {max_positions} positions "
                "precomputed for rotary embeddings (2 * max_seq_len)."
            )
        h = self.tok_embeddings(tokens)
        # Move only the slice to the activation device; do not reassign the
        # registered buffer inside forward (module.to() already moves buffers).
        freqs_cis = self.freqs_cis[0:seqlen].to(h.device)
        for layer in self.layers:
            h = layer(h, freqs_cis)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    @classmethod
    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
        """
        Initialize a Transformer model from a ModelArgs object.

        Args:
            model_args (ModelArgs): Model configuration arguments.

        Returns:
            Transformer: Transformer model.
        """
        return cls(model_args)