-
Notifications
You must be signed in to change notification settings - Fork 13
/
05.mha.sql
178 lines (178 loc) · 5.59 KB
/
05.mha.sql
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
WITH embeddings AS
(
    -- Input embeddings for the hard-coded prompt: one row per token, where
    -- place = 0-based position of the token in the array and
    -- values = token embedding (wte) + position embedding (wpe) for that position.
    -- A token with no wte row (or a position with no wpe row) yields no output row.
    SELECT tokens.ordinality - 1 AS place,
           wte.values + wpe.values AS values
    FROM UNNEST(ARRAY[6307, 47701, 318, 1049]) WITH ORDINALITY AS tokens (token, ordinality)
    INNER JOIN wte
        ON wte.token = tokens.token
    INNER JOIN wpe
        ON wpe.place = tokens.ordinality - 1
),
-- Weights restricted to block 0, the only block this script evaluates.
-- Each CTE deliberately shadows the base table of the same name: a non-recursive CTE cannot
-- reference itself, so inside each body the name still means the table, while every later CTE
-- in this statement sees only the block-0 rows.
-- The MLP weight tables (mlp_c_fc_w/_b, mlp_c_proj_w/_b) were previously filtered here too but
-- were never referenced by this attention step, so they are omitted.
-- NOTE(review): SELECT * is kept because the weight-table schemas are not visible here;
-- downstream CTEs read values (plus y for c_attn_w and place for c_proj_w) — list the columns
-- explicitly once the schemas are confirmed.

-- First layer norm: scale (ln_1_g) and shift (ln_1_b).
ln_1_g AS
(
    SELECT ln_1_g.*
    FROM ln_1_g
    WHERE ln_1_g.block = 0
),
ln_1_b AS
(
    SELECT ln_1_b.*
    FROM ln_1_b
    WHERE ln_1_b.block = 0
),
-- Projection to the concatenated query/key/value vectors, and its bias.
c_attn_w AS
(
    SELECT c_attn_w.*
    FROM c_attn_w
    WHERE c_attn_w.block = 0
),
c_attn_b AS
(
    SELECT c_attn_b.*
    FROM c_attn_b
    WHERE c_attn_b.block = 0
),
-- Output projection applied to the concatenated heads, and its bias.
c_proj_w AS
(
    SELECT c_proj_w.*
    FROM c_proj_w
    WHERE c_proj_w.block = 0
),
c_proj_b AS
(
    SELECT c_proj_b.*
    FROM c_proj_b
    WHERE c_proj_b.block = 0
),
-- Layer norm (ln_1) of every embedding followed by the c_attn projection: one 2304-wide vector
-- per position, which heads later slices into query (1..768), key (769..1536) and value
-- (1537..2304) parts. All column references are qualified so the query does not depend on the
-- weight tables lacking columns named place / values.
mha_norm AS
(
    SELECT projected.place,
           projected.values + c_attn_b.values AS values
    FROM (
        -- Matrix-vector product: one INNER_PRODUCT per c_attn_w row, assembled in c_attn_w.y order.
        SELECT layer_norm.place,
               ARRAY_AGG(INNER_PRODUCT(c_attn_w.values, layer_norm.values) ORDER BY c_attn_w.y)::VECTOR(2304) AS values
        FROM (
            -- Apply the learned scale (ln_1_g) and shift (ln_1_b) to the standardized vector.
            SELECT normalized.place,
                   normalized.values * ln_1_g.values + ln_1_b.values AS values
            FROM (
                -- Standardize each embedding: (value - mean) / SQRT(variance + 1E-5),
                -- using the population variance of its components.
                SELECT embeddings.place,
                       norm.values
                FROM embeddings
                CROSS JOIN LATERAL
                (
                    SELECT AVG(value) AS mean,
                           VAR_POP(value) AS variance
                    FROM UNNEST(embeddings.values::REAL[]) AS value
                ) stats
                CROSS JOIN LATERAL
                (
                    SELECT ARRAY_AGG((n.value - stats.mean) / SQRT(stats.variance + 1E-5) ORDER BY n.ordinality)::VECTOR(768) AS values
                    FROM UNNEST(embeddings.values::REAL[]) WITH ORDINALITY AS n (value, ordinality)
                ) norm
            ) normalized
            CROSS JOIN ln_1_b
            CROSS JOIN ln_1_g
        ) layer_norm
        CROSS JOIN c_attn_w
        GROUP BY layer_norm.place
    ) projected
    CROSS JOIN c_attn_b
),
-- Split each position's 2304-wide vector into 12 heads (head = 0..11) of width 64.
-- Head h takes components h*64+1 .. h*64+64 of each 768-wide section:
-- section 0 is q, section 1 (offset 768) is k, section 2 (offset 1536) is v.
heads AS
(
    SELECT m.place,
           h.head,
           (m.values::REAL[])[h.head * 64 + 1 : h.head * 64 + 64]::VECTOR(64) AS q,
           (m.values::REAL[])[768 + h.head * 64 + 1 : 768 + h.head * 64 + 64]::VECTOR(64) AS k,
           (m.values::REAL[])[1536 + h.head * 64 + 1 : 1536 + h.head * 64 + 64]::VECTOR(64) AS v
    FROM mha_norm AS m
    CROSS JOIN GENERATE_SERIES(0, 11) AS h (head)
),
-- Attention scores per head for every (query position x, key position y) pair:
-- q(x) . k(y) / 8, where 8 = SQRT(64) and 64 is the per-head width from heads.
-- Keys after the query position (y > x) get -1E10 added so the softmax gives them ~0 weight.
sm_input AS
(
    SELECT h1.head AS head,
           h1.place AS x,
           h2.place AS y,
           INNER_PRODUCT(h1.q, h2.k) / 8
               + CASE
                     WHEN h2.place > h1.place THEN -1E10
                     ELSE 0
                 END AS value
    FROM heads AS h1
    INNER JOIN heads AS h2
        ON h2.head = h1.head
),
-- Softmax, step 1: subtract the maximum score of each (head, x) row for numerical stability.
sm_diff AS
(
    SELECT s.head,
           s.x,
           s.y,
           s.value - MAX(s.value) OVER score_row AS diff
    FROM sm_input AS s
    WINDOW score_row AS (PARTITION BY s.head, s.x)
),
-- Softmax, step 2: exponentiate. Below -745.13 (about where a double-precision EXP underflows
-- to zero) the result is set to 0 directly instead of calling EXP.
sm_exp AS
(
    SELECT d.head,
           d.x,
           d.y,
           CASE
               WHEN d.diff < -745.13 THEN 0
               ELSE EXP(d.diff)
           END AS e
    FROM sm_diff AS d
),
-- Softmax, step 3: normalize each (head, x) row to sum to 1; the key position y is exposed as place.
softmax AS
(
    SELECT p.head,
           p.x,
           p.y AS place,
           p.e / SUM(p.e) OVER score_row AS value
    FROM sm_exp AS p
    WINDOW score_row AS (PARTITION BY p.head, p.x)
),
-- Attention output per position: for each (head, x), the softmax-weighted sum of the value
-- vectors of all key positions (each weight is broadcast to a 64-vector with ARRAY_FILL and
-- multiplied elementwise with heads.v). The 12 per-head 64-vectors are then concatenated in
-- head order into one 768-vector per position.
attention AS
(
    SELECT per_head.place,
           ARRAY_AGG(elem.value ORDER BY per_head.head, elem.ordinality)::VECTOR(768) AS values
    FROM (
        SELECT softmax.head AS head,
               softmax.x AS place,
               SUM(ARRAY_FILL(softmax.value, ARRAY[64])::VECTOR(64) * heads.v) AS values
        FROM softmax
        INNER JOIN heads
            ON heads.head = softmax.head
           AND heads.place = softmax.place
        GROUP BY softmax.head,
                 softmax.x
    ) per_head
    CROSS JOIN LATERAL UNNEST(per_head.values::REAL[]) WITH ORDINALITY AS elem (value, ordinality)
    GROUP BY per_head.place
),
-- Output projection of the concatenated heads: one INNER_PRODUCT per c_proj_w row
-- (assembled in c_proj_w.place order) into a 768-vector, plus the c_proj_b bias.
mha AS
(
    SELECT projected.place,
           projected.values + c_proj_b.values AS values
    FROM (
        SELECT attention.place,
               ARRAY_AGG(INNER_PRODUCT(attention.values, c_proj_w.values) ORDER BY c_proj_w.place)::VECTOR(768) AS values
        FROM attention
        CROSS JOIN c_proj_w
        GROUP BY attention.place
    ) projected
    CROSS JOIN c_proj_b
)
-- Show the first 10 components of each position's attention output, formatted as signed
-- 'S0.000' numbers followed by ' …'. Rows are ordered by position and components by their
-- index so the output is deterministic (neither order is guaranteed without explicit ORDER BY).
-- NOTE(review): TO_CHAR prints '#' placeholders for values with an absolute value of 10 or more
-- under the 'S0.000' mask.
SELECT mha.place,
       (
           SELECT STRING_AGG(TO_CHAR(c.n, 'S0.000'), ' ' ORDER BY c.ordinality) || ' …'
           FROM UNNEST((mha.values::REAL[])[:10]) WITH ORDINALITY AS c (n, ordinality)
       ) AS q
FROM mha
ORDER BY mha.place;