-
Notifications
You must be signed in to change notification settings - Fork 12
/
net_models.py
293 lines (256 loc) · 13.2 KB
/
net_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class GaussianFourierProjection(nn.Module):
    """Embed scalar time steps with fixed random Fourier features.

    Args:
        embed_dim: Output width. ``embed_dim // 2`` random frequencies are drawn;
            each contributes one sine and one cosine feature.
        scale: Standard deviation of the sampled frequencies.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Frequencies are sampled once and frozen (requires_grad=False): they are
        # registered on the module but never updated by the optimizer.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # Presumably x holds one time value per batch element, shape (B,) -- TODO confirm.
        phase = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * math.pi
        return torch.cat((phase.sin(), phase.cos()), dim=-1)
class Dense(nn.Module):
    """Linear projection whose output broadcasts over a feature map.

    Two trailing singleton axes are appended to the projection, so the result can
    be added directly to a convolution output (e.g. to inject a time embedding
    channel-wise).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        return projected.unsqueeze(-1).unsqueeze(-1)
class CrossAttention(nn.Module):
    """Scaled dot-product attention usable as self- or cross-attention.

    Queries are always projected from ``tokens``. With ``context_dim=None`` keys
    and values are projected from ``tokens`` as well (self-attention); otherwise
    they are projected from a separate ``context`` sequence (cross-attention).

    Args:
        embed_dim: Width of the query/key projections.
        hidden_dim: Feature width of ``tokens`` and of the returned tensor.
        context_dim: Feature width of ``context``; ``None`` selects self-attention.
        num_heads: Number of attention heads. ``embed_dim`` and ``hidden_dim``
            must both be divisible by it. The default of 1 is plain single-head
            attention.

    Raises:
        ValueError: ``num_heads`` is < 1 or does not divide both widths.
    """

    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1, ):
        super(CrossAttention, self).__init__()
        if num_heads < 1 or embed_dim % num_heads or hidden_dim % num_heads:
            raise ValueError(
                f"num_heads={num_heads} must be >= 1 and divide both "
                f"embed_dim={embed_dim} and hidden_dim={hidden_dim}")
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.self_attn = context_dim is None
        # Keys/values are read from tokens (self-attention) or context (cross).
        source_dim = hidden_dim if self.self_attn else context_dim
        # Created in query, key, value order so seeded initialisation is stable.
        self.query = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.key = nn.Linear(source_dim, embed_dim, bias=False)
        self.value = nn.Linear(source_dim, hidden_dim, bias=False)

    def _split_heads(self, x):
        """Reshape (B, L, D) into (B, num_heads, L, D // num_heads)."""
        batch, length, _ = x.shape
        return x.reshape(batch, length, self.num_heads, -1).transpose(1, 2)

    def forward(self, tokens, context=None):
        """Attend from ``tokens`` of shape (B, T, hidden_dim); return (B, T, hidden_dim).

        In self-attention mode ``context`` is ignored. In cross-attention mode it
        is required; presumably shaped (B, S, context_dim).

        Raises:
            ValueError: cross-attention was called without ``context``.
        """
        if self.self_attn:
            source = tokens
        elif context is None:
            raise ValueError("CrossAttention built with context_dim requires a context tensor")
        else:
            source = context
        q = self._split_heads(self.query(tokens))
        k = self._split_heads(self.key(source))
        v = self._split_heads(self.value(source))
        # Scale by sqrt of the per-head key width (equals embed_dim when num_heads == 1).
        head_dim = self.embed_dim // self.num_heads
        scores = torch.einsum("bhtd,bhsd->bhts", q, k) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        out = torch.einsum("bhts,bhsd->bhtd", weights, v)
        # Merge heads: (B, h, T, d) -> (B, T, h * d) = (B, T, hidden_dim).
        batch, _, length, _ = out.shape
        return out.transpose(1, 2).reshape(batch, length, -1)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention, feed-forward.

    Each sub-layer reads a LayerNorm-ed view of the running features and its
    result is added back residually. Features have width ``hidden_dim``; the
    cross-attention context has width ``context_dim``.
    """

    def __init__(self, hidden_dim, context_dim):
        super(TransformerBlock, self).__init__()
        # Submodules are registered in this order; changing it would alter seeded
        # initialisation and state_dict ordering.
        self.attn_self = CrossAttention(hidden_dim, hidden_dim)
        self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        inner_width = 3 * hidden_dim  # feed-forward expansion width
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, hidden_dim),
        )

    def forward(self, x, context=None):
        h = x + self.attn_self(self.norm1(x))
        h = h + self.attn_cross(self.norm2(h), context=context)
        return h + self.ffn(self.norm3(h))
class SpatialTransformer(nn.Module):
    """Run a TransformerBlock over the pixels of a feature map, with a residual.

    A (B, C, H, W) map is flattened into H*W tokens of width C, passed through the
    transformer (optionally cross-attending to ``context``), reshaped back to
    (B, C, H, W), and added to the input. ``hidden_dim`` must equal C, since the
    transformer's LayerNorms operate on that width.
    """

    def __init__(self, hidden_dim, context_dim):
        super(SpatialTransformer, self).__init__()
        self.transformer = TransformerBlock(hidden_dim, context_dim)

    def forward(self, x, context=None):
        batch, channels, height, width = x.shape
        # (b, c, h, w) -> (b, h*w, c): one token per pixel.
        tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        tokens = self.transformer(tokens, context)
        # (b, h*w, c) -> (b, c, h, w)
        restored = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2)
        return restored + x
class UNet_Tranformer_attrb(nn.Module):
    """A time-dependent, attribute-conditioned score model built on a U-Net.

    Three stride-2 convolutions downsample the input. The bottleneck applies a
    SpatialTransformer that cross-attends to embedded attribute tokens. Three
    stride-2 transposed convolutions then upsample with additive skip
    connections. The network output is divided by ``marginal_prob_std(t)``.
    """

    def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256,
                 text_dim=256, nAttr=40):
        """Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each of the four
                resolutions (a tuple default avoids a shared mutable default).
            embed_dim: The dimensionality of Gaussian random feature embeddings.
            text_dim: Width of the attribute-token embeddings used as the
                cross-attention context.
            nAttr: Number of attribute ids; id ``nAttr`` is reserved for padding.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time.
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))
        # Encoding layers where the resolution decreases. conv1 is unpadded, so it
        # trims 2 pixels from each spatial dimension.
        self.conv1 = nn.Conv2d(3, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        # NOTE: attn3 is not called in forward(), but its parameters are part of
        # state_dict; removing it would change checkpoint keys.
        self.attn3 = SpatialTransformer(channels[2], text_dim)
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        self.attn4 = SpatialTransformer(channels[3], text_dim)
        # Decoding layers where the resolution increases. output_padding=1 makes
        # e.g. 32x32 and 64x64 inputs come back at their original size.
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        # NOTE: attn5 is likewise built but not called in forward().
        self.attn5 = SpatialTransformer(channels[2], text_dim)
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 3, 3, stride=1)
        # The swish activation function.
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std
        # Attribute-id embeddings; the extra row at index nAttr is the padding token.
        self.cond_embed = nn.Embedding(nAttr + 1, text_dim, padding_idx=nAttr)

    def forward(self, x, t, y=None):
        """Estimate the score of ``x`` at diffusion time ``t``.

        Args:
            x: Image batch with 3 channels, presumably (B, 3, H, W).
            t: Diffusion times, one per batch element.
            y: Long tensor of attribute ids in ``[0, nAttr]`` (``nAttr`` = padding),
                shaped (B, S) for S tokens per sample or (B,) for one token each.
                If None, one padding token per sample is used. Its embedding row
                is zero-initialised and receives no gradient (unless weights are
                overwritten). With the bias-free key/value projections, the
                cross-attention term is therefore zero, i.e. an unconditional pass.

        Returns:
            Score estimate; same shape as ``x`` for sizes that round-trip through
            the U-Net (e.g. 32x32, 64x64).
        """
        # Time embedding, injected at every level through the DenseN projections.
        embed = self.act(self.embed(t))
        if y is None:
            y = torch.full((x.shape[0], 1), self.cond_embed.padding_idx,
                           dtype=torch.long, device=x.device)
        elif y.dim() == 1:
            # One attribute per sample -> a length-1 token sequence.
            y = y.unsqueeze(1)
        # (B, S, text_dim) context for cross-attention.
        y_embed = self.cond_embed(y)
        # Encoding path: conv, add time embedding, group-norm, swish.
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        # Bottleneck attention conditioned on the attribute tokens.
        h4 = self.attn4(h4, y_embed)
        # Decoding path with additive skip connections from the encoder.
        h = self.tconv4(h4) + self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)
        # Normalize output by the perturbation-kernel std at time t.
        return h / self.marginal_prob_std(t)[:, None, None, None]
class ResBlock(nn.Module):
    """Unfinished residual block holding three shape-preserving 3x3 convolutions.

    NOTE(review): no ``forward`` is defined, so calling an instance raises
    NotImplementedError from ``nn.Module``. ``out_chan``, ``stride`` and
    ``downsample`` are accepted but currently unused; every conv maps
    ``in_chan`` -> ``in_chan`` with stride 1 and padding 1.
    """

    def __init__(self, in_chan, out_chan, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_chan, in_chan, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_chan, in_chan, **conv_kwargs)
        self.conv3 = nn.Conv2d(in_chan, in_chan, **conv_kwargs)
class UNet_Tranformer_ResBlk_attrb(nn.Module):
    """A time-dependent, attribute-conditioned score model built on a U-Net.

    Despite its name, this class does not use ResBlock. Three stride-2
    convolutions downsample the input. The bottleneck applies a
    SpatialTransformer that cross-attends to embedded attribute tokens. Three
    stride-2 transposed convolutions then upsample with additive skip
    connections. The network output is divided by ``marginal_prob_std(t)``.
    """

    def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256,
                 text_dim=256, nAttr=40):
        """Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each of the four
                resolutions (a tuple default avoids a shared mutable default).
            embed_dim: The dimensionality of Gaussian random feature embeddings.
            text_dim: Width of the attribute-token embeddings used as the
                cross-attention context.
            nAttr: Number of attribute ids; id ``nAttr`` is reserved for padding.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time.
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))
        # Encoding layers where the resolution decreases. conv1 is unpadded, so it
        # trims 2 pixels from each spatial dimension.
        self.conv1 = nn.Conv2d(3, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        # NOTE: attn3 is not called in forward(), but its parameters are part of
        # state_dict; removing it would change checkpoint keys.
        self.attn3 = SpatialTransformer(channels[2], text_dim)
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        self.attn4 = SpatialTransformer(channels[3], text_dim)
        # Decoding layers where the resolution increases. output_padding=1 makes
        # e.g. 32x32 and 64x64 inputs come back at their original size.
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        # NOTE: attn5 is likewise built but not called in forward().
        self.attn5 = SpatialTransformer(channels[2], text_dim)
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False,
                                         output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 3, 3, stride=1)
        # The swish activation function.
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std
        # Attribute-id embeddings; the extra row at index nAttr is the padding token.
        self.cond_embed = nn.Embedding(nAttr + 1, text_dim, padding_idx=nAttr)

    def forward(self, x, t, y=None):
        """Estimate the score of ``x`` at diffusion time ``t``.

        Args:
            x: Image batch with 3 channels, presumably (B, 3, H, W).
            t: Diffusion times, one per batch element.
            y: Long tensor of attribute ids in ``[0, nAttr]`` (``nAttr`` = padding),
                shaped (B, S) for S tokens per sample or (B,) for one token each.
                If None, one padding token per sample is used. Its embedding row
                is zero-initialised and receives no gradient (unless weights are
                overwritten). With the bias-free key/value projections, the
                cross-attention term is therefore zero, i.e. an unconditional pass.

        Returns:
            Score estimate; same shape as ``x`` for sizes that round-trip through
            the U-Net (e.g. 32x32, 64x64).
        """
        # Time embedding, injected at every level through the DenseN projections.
        embed = self.act(self.embed(t))
        if y is None:
            y = torch.full((x.shape[0], 1), self.cond_embed.padding_idx,
                           dtype=torch.long, device=x.device)
        elif y.dim() == 1:
            # One attribute per sample -> a length-1 token sequence.
            y = y.unsqueeze(1)
        # (B, S, text_dim) context for cross-attention.
        y_embed = self.cond_embed(y)
        # Encoding path: conv, add time embedding, group-norm, swish.
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        # Bottleneck attention conditioned on the attribute tokens.
        h4 = self.attn4(h4, y_embed)
        # Decoding path with additive skip connections from the encoder.
        h = self.tconv4(h4) + self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)
        # Normalize output by the perturbation-kernel std at time t.
        return h / self.marginal_prob_std(t)[:, None, None, None]