-
Notifications
You must be signed in to change notification settings - Fork 0
/
p3D.py
398 lines (327 loc) · 14 KB
/
p3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# -*- coding: utf-8 -*-
#modified done
#2019.06.03
#haidong
#https://github.com/qijiezhao/pseudo-3d-pytorch
"""
本文所做的事情
提出了用(1,3,3)的空间卷积和(3,1,1)的时间卷积来近似代替(3,3,3)的3D卷积的思想;
将现有的2D Conv扩展成3D Conv,并用以上伪3D的思想实现了一个新的Pseudo-3D ResNet(P3D);
这样既能够利用3D结构来提取视频的空间时序信息,又能利用原来在Imagenet上预训练的参数做模型初始化;
文章的P3D网络结构最后在Sports-1M视频分类数据集上达到的精度分别比原来的3D CNN和基于帧的2D CNN网络结构高出5.3%和1.8%。
除此之外,文章还在5个不同的数据集上进行了3个不同的任务来验证模型的泛化能力,并且都取得了不错的结果。
"""
from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
# Public API re-exported by ``from p3D import *``: the network class and its
# three depth-specific constructors.
__all__ = [
    'P3D',
    'P3D63',
    'P3D131',
    'P3D199',
]
def conv_S(in_planes, out_planes, stride=1, padding=1):
    """Build the spatial ("S") convolution of a P3D block: a bias-free 1x3x3 ``nn.Conv3d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (int or 3-tuple). It used to be accepted but
            ignored (the layer was always built with stride 1); the callers in
            this file pass 1, so honouring it changes nothing for them.
        padding: convolution padding (int or 3-tuple). The int default pads the
            temporal axis too (T grows by 2 with the size-1 temporal kernel);
            pass ``(0, 1, 1)`` to preserve T, H and W.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,
                     padding=padding, bias=False)
def conv_T(in_planes, out_planes, stride=1, padding=1):
    """Build the temporal ("T") convolution of a P3D block: a bias-free 3x1x1 ``nn.Conv3d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (int or 3-tuple). It used to be accepted but
            ignored (the layer was always built with stride 1); the callers in
            this file pass 1, so honouring it changes nothing for them.
        padding: convolution padding (int or 3-tuple). The int default pads H and
            W too (each grows by 2 with the size-1 spatial kernel); pass
            ``(1, 0, 0)`` to preserve T, H and W.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride,
                     padding=padding, bias=False)
def downsample_basic_block(x, planes, stride):
    """Parameter-free ("type A") shortcut: subsample ``x`` and zero-pad its channels.

    Args:
        x: 5-D input tensor (``out.size(4)`` below requires five dimensions).
        planes: number of output channels; must be >= ``x.size(1)``.
        stride: subsampling stride (int or 3-tuple), given to ``F.avg_pool3d``
            with ``kernel_size=1`` so it simply keeps every ``stride``-th element.

    Returns:
        Tensor with ``planes`` channels on the same device and dtype as ``x``:
        the first ``x.size(1)`` channels are the subsampled input, the rest zeros.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # new_zeros inherits out's dtype and device, replacing the manual
    # torch.cuda.FloatTensor check (which mishandled e.g. half precision).
    zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                              out.size(2), out.size(3), out.size(4))
    # Concatenate the tensors themselves rather than ``.data`` wrapped in a new
    # Variable: the old form detached the shortcut from the autograd graph, so no
    # gradient flowed back to ``x`` through it.
    return torch.cat([out, zero_pads], dim=1)
class Bottleneck(nn.Module):
    """P3D bottleneck residual block: 1x1 conv -> core -> 1x1 conv, plus a shortcut.

    Blocks whose global index ``n_s`` is below ``depth_3d`` are built from 3-D
    layers. Their core is one of the pseudo-3D variants ``'A'``, ``'B'`` or ``'C'``,
    picked cyclically from ``ST_struc`` by ``n_s``. The remaining blocks are plain
    2-D bottlenecks with a 3x3 ``Conv2d`` core.

    Args:
        inplanes: input channels.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stored as ``self.stride``. The stride actually used by ``conv1`` is:
            3-D part -> 1 if ``n_s == 0``, else ``(1, 2, 2)`` when ``downsample`` is
            given, else ``stride``; 2-D part -> 2 if ``n_s == depth_3d``, else 1.
        downsample: optional callable applied to the block input to form the shortcut.
        n_s: global index of this block in the network.
        depth_3d: number of leading blocks built from 3-D layers.
        ST_struc: sequence of core types, each one of ``'A'``, ``'B'``, ``'C'``.

    Raises:
        ValueError: if ``ST_struc`` is empty or contains an unknown core type.
    """

    expansion = 4
    # Core types that forward() knows how to dispatch.
    _VALID_ST = ('A', 'B', 'C')

    def __init__(self, inplanes, planes, stride=1, downsample=None, n_s=0, depth_3d=47,
                 ST_struc=('A', 'B', 'C')):
        super(Bottleneck, self).__init__()
        # Validate first: an unknown letter used to make forward() silently skip
        # the spatial/temporal convolutions, and an empty sequence crashed with
        # ZeroDivisionError in the modulo below.
        if len(ST_struc) == 0:
            raise ValueError("ST_struc must contain at least one of 'A', 'B', 'C'")
        unknown = [s for s in ST_struc if s not in self._VALID_ST]
        if unknown:
            raise ValueError("Unknown ST_struc entries {}; expected 'A', 'B' or 'C'".format(unknown))

        self.downsample = downsample
        self.depth_3d = depth_3d
        self.ST_struc = ST_struc
        self.len_ST = len(self.ST_struc)

        # conv1 carries the block's spatial downsampling (see class docstring).
        stride_p = stride
        if self.downsample is not None:
            stride_p = (1, 2, 2)
        if n_s < self.depth_3d:
            if n_s == 0:
                stride_p = 1
            self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False, stride=stride_p)
            self.bn1 = nn.BatchNorm3d(planes)
        else:
            # Only the first 2-D block downsamples (by 2 in H and W).
            stride_p = 2 if n_s == self.depth_3d else 1
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride_p)
            self.bn1 = nn.BatchNorm2d(planes)

        self.id = n_s
        self.ST = list(self.ST_struc)[self.id % self.len_ST]
        if self.id < self.depth_3d:
            # Pseudo-3D core: 1x3x3 spatial conv (S) and 3x1x1 temporal conv (T),
            # padded so that T, H and W are preserved.
            self.conv2 = conv_S(planes, planes, stride=1, padding=(0, 1, 1))
            self.bn2 = nn.BatchNorm3d(planes)
            self.conv3 = conv_T(planes, planes, stride=1, padding=(1, 0, 0))
            self.bn3 = nn.BatchNorm3d(planes)
        else:
            self.conv_normal = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn_normal = nn.BatchNorm2d(planes)

        out_planes = planes * self.expansion
        if n_s < self.depth_3d:
            self.conv4 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
            self.bn4 = nn.BatchNorm3d(out_planes)
        else:
            self.conv4 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
            self.bn4 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def ST_A(self, x):
        """P3D-A core: spatial conv S followed by temporal conv T (cascade)."""
        x = self.relu(self.bn2(self.conv2(x)))
        return self.relu(self.bn3(self.conv3(x)))

    def ST_B(self, x):
        """P3D-B core: S and T applied in parallel to the same input, outputs summed."""
        s_out = self.relu(self.bn2(self.conv2(x)))
        t_out = self.relu(self.bn3(self.conv3(x)))
        return t_out + s_out

    def ST_C(self, x):
        """P3D-C core: S, then T on S's output, with S's output added back."""
        s_out = self.relu(self.bn2(self.conv2(x)))
        t_out = self.relu(self.bn3(self.conv3(s_out)))
        return s_out + t_out

    def forward(self, x):
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))

        if self.id < self.depth_3d:
            # 3-D part (res2-res4): dispatch to the selected pseudo-3D core.
            if self.ST == 'A':
                out = self.ST_A(out)
            elif self.ST == 'B':
                out = self.ST_B(out)
            elif self.ST == 'C':
                out = self.ST_C(out)
        else:
            # 2-D part (res5): ordinary 3x3 convolution.
            out = self.relu(self.bn_normal(self.conv_normal(out)))

        out = self.bn4(self.conv4(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)
class P3D(nn.Module):
    """Pseudo-3D ResNet.

    res2-res4 (``layer1``-``layer3``) are built from 3-D blocks. res5 (``layer4``)
    is built from 2-D blocks, applied after the remaining temporal axis has been
    folded into the batch.

    Args:
        block: residual block class (e.g. ``Bottleneck``). It must expose
            ``expansion`` and accept
            ``(inplanes, planes, stride, downsample, n_s=, depth_3d=, ST_struc=)``.
        layers: number of blocks in each of the four stages.
        modality: ``'RGB'`` selects 3 input channels; any other value selects 2
            (optical flow).
        shortcut_type: ``'A'`` for the parameter-free zero-padded shortcut,
            anything else for a 1x1 conv + BN shortcut.
        num_classes: classifier output size.
        dropout: dropout probability before the classifier.
        ST_struc: pseudo-3D core types cycled over the 3-D blocks.
    """

    def __init__(self, block, layers, modality='RGB',
                 shortcut_type='B', num_classes=400, dropout=0.5, ST_struc=('A', 'B', 'C')):
        super(P3D, self).__init__()
        self.inplanes = 64  # input channels of the next stage; advanced by _make_layer
        self.input_channel = 3 if modality == 'RGB' else 2  # 2 is for flow
        self.ST_struc = ST_struc

        # Stem: 1x7x7 convolution, stride 2 in H and W only.
        self.conv1_custom = nn.Conv3d(self.input_channel, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2),
                                      padding=(0, 3, 3), bias=False)
        # Blocks with global index < depth_3d (res2, res3, res4) are 3-D; res5 is 2-D.
        self.depth_3d = sum(layers[:3])
        self.bn1 = nn.BatchNorm3d(64)
        self.cnt = 0  # global block index, advanced by _make_layer
        self.relu = nn.ReLU(inplace=True)
        # Stem pooling: stride 2 on T, H and W.
        self.maxpool = nn.MaxPool3d(kernel_size=(2, 3, 3), stride=2, padding=(0, 1, 1))
        # Applied after res2, res3 and res4: halves T only.
        self.maxpool_2 = nn.MaxPool3d(kernel_size=(2, 1, 1), padding=0, stride=(2, 1, 1))

        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2)

        # Global spatial average of the res5 map. For the default 160x160 crop with
        # Bottleneck blocks that map is 5x5, where this equals the former fixed
        # AvgPool2d(5, stride=1). Unlike that layer, it also gives one vector per
        # sample for other crop sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._init_weights()

        # Preprocessing metadata consumed by data pipelines.
        self.input_size = (self.input_channel, 16, 160, 160)  # network input (C, T, H, W)
        self.input_mean = [0.485, 0.456, 0.406] if modality == 'RGB' else [0.5]
        self.input_std = [0.229, 0.224, 0.225] if modality == 'RGB' else [np.mean([0.229, 0.224, 0.225])]

    def _init_weights(self):
        """He (fan-out) normal init for every conv; unit scale / zero shift for every BN."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d)):
                # fan-out over *all* kernel dimensions. The old formula used only
                # kernel_size[0] * kernel_size[1], dropping the third 3-D dimension,
                # and the 2-D res5 convolutions were not initialised at all.
                n = m.out_channels
                for k in m.kernel_size:
                    n *= k
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    @property
    def scale_size(self):
        # Assumes raw frames are resized to (340, 256).
        return self.input_size[2] * 256 // 160

    @property
    def temporal_length(self):
        return self.input_size[1]

    @property
    def crop_size(self):
        return self.input_size[2]

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one residual stage of ``blocks`` blocks, advancing ``self.cnt``/``self.inplanes``."""
        downsample = None
        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            if self.cnt < self.depth_3d:
                # 3-D stage: match the block's conv1 stride -- the first stage keeps
                # the resolution, later ones halve H and W but never T (T is reduced
                # by maxpool_2 instead).
                stride_p = 1 if self.cnt == 0 else (1, 2, 2)
                if shortcut_type == 'A':
                    # Passing stride_p (not the int ``stride``) keeps T intact, so the
                    # shortcut shape matches the main path.
                    downsample = partial(downsample_basic_block,
                                         planes=out_planes,
                                         stride=stride_p)
                else:
                    downsample = nn.Sequential(
                        nn.Conv3d(self.inplanes, out_planes,
                                  kernel_size=1, stride=stride_p, bias=False),
                        nn.BatchNorm3d(out_planes)
                    )
            else:
                if shortcut_type == 'A':
                    # NOTE(review): downsample_basic_block uses avg_pool3d and reads
                    # five dimensions, so it does not appear to accept the 4-D
                    # tensors of this 2-D stage -- confirm before relying on it.
                    downsample = partial(downsample_basic_block,
                                         planes=out_planes,
                                         stride=stride)
                else:
                    downsample = nn.Sequential(
                        nn.Conv2d(self.inplanes, out_planes,
                                  kernel_size=1, stride=2, bias=False),
                        nn.BatchNorm2d(out_planes)
                    )

        layers = [block(self.inplanes, planes, stride, downsample,
                        n_s=self.cnt, depth_3d=self.depth_3d, ST_struc=self.ST_struc)]
        self.cnt += 1
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                n_s=self.cnt, depth_3d=self.depth_3d, ST_struc=self.ST_struc))
            self.cnt += 1
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (N * T', num_classes), where T' is the temporal size after res4.

        T' is 1 for the default 16-frame input, giving (N, num_classes).
        """
        x = self.conv1_custom(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.maxpool_2(self.layer1(x))  # res2
        x = self.maxpool_2(self.layer2(x))  # res3
        x = self.maxpool_2(self.layer3(x))  # res4

        # Fold time into the batch for the 2-D res5: (N, C, T, H, W) -> (N*T, C, H, W).
        # The transpose keeps each frame's channels together; a plain view() mixed
        # channels and time whenever T > 1 (and is identical when T == 1).
        n, c, t, h, w = x.size()
        x = x.transpose(1, 2).reshape(n * t, c, h, w)
        x = self.layer4(x)  # res5

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(self.dropout(x))
        return x
def P3D63(**kwargs):
    """Build a P3D63 network, based on a ResNet-50-3D layout (3-4-6-3 bottleneck blocks).

    Every keyword argument is forwarded unchanged to ``P3D``.
    """
    stage_depths = [3, 4, 6, 3]
    return P3D(Bottleneck, stage_depths, **kwargs)
def P3D131(**kwargs):
    """Build a P3D131 network, based on a ResNet-101-3D layout (3-4-23-3 bottleneck blocks).

    Every keyword argument is forwarded unchanged to ``P3D``.
    """
    stage_depths = [3, 4, 23, 3]
    return P3D(Bottleneck, stage_depths, **kwargs)
def P3D199(pretrained=False, modality='RGB', **kwargs):
    """Build a P3D199 network, based on a ResNet-152-3D layout (3-8-36-3 bottleneck blocks).

    Args:
        pretrained: if true, load weights from the checkpoint file for ``modality``
            (``p3d_rgb_199.checkpoint.pth.tar`` or ``p3d_flow_199.checkpoint.pth.tar``,
            a relative path, i.e. resolved against the current working directory).
        modality: ``'RGB'`` or ``'Flow'``; forwarded to ``P3D``.
        **kwargs: forwarded to ``P3D``.

    Returns:
        The constructed ``P3D`` model.

    Raises:
        ValueError: if ``pretrained`` is requested for a modality with no checkpoint.
    """
    pretrained_files = {
        'RGB': 'p3d_rgb_199.checkpoint.pth.tar',
        'Flow': 'p3d_flow_199.checkpoint.pth.tar',
    }
    # Fail fast, before building the large network. Previously an unsupported
    # modality led to an UnboundLocalError on the checkpoint path.
    if pretrained and modality not in pretrained_files:
        raise ValueError("No pretrained P3D199 checkpoint for modality {!r}; "
                         "expected one of {}".format(modality, sorted(pretrained_files)))

    model = P3D(Bottleneck, [3, 8, 36, 3], modality=modality, **kwargs)
    if pretrained:
        # map_location='cpu' lets checkpoints saved from GPU load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's parameters.
        weights = torch.load(pretrained_files[modality], map_location='cpu')['state_dict']
        model.load_state_dict(weights)
    return model
# custom operation
def get_optim_policies(model=None, modality='RGB', enable_pbn=True, slow_rate=0.7):
    """Split ``model``'s parameters into optimizer parameter groups.

    Groups, in ``model.modules()`` order:
      * first_conv_weight / first_conv_bias -- the first Conv2d/Conv3d met.
      * slow_feat / slow_bias -- weights/biases of the first ``slow_rate`` fraction
        of the remaining conv and Linear layers (meant to be fine-tuned slowly).
      * normal_feat / normal_bias -- weights/biases of the rest of those layers.
      * "BN scale/shift" -- BatchNorm affine parameters. With ``enable_pbn`` only the
        first BatchNorm3d is included; every BatchNorm2d is always included.

    Args:
        model: the network to inspect.
        modality: ``'Flow'`` raises the first conv's lr multipliers (5 / 10 instead of 1 / 2).
        enable_pbn: partial-BN mode; see the BN group above.
        slow_rate: fraction of non-first conv/Linear layers put into the slow groups.

    Returns:
        A list of seven dicts, each with ``params``, ``lr_mult``, ``decay_mult`` and ``name``.

    Raises:
        ValueError: if ``model`` is None, or if it contains a leaf module with
            parameters of a type that has no policy here.
    """
    # Raise instead of calling exit(): a library function must not kill the process.
    if model is None:
        raise ValueError("get_optim_policies() requires a model")

    first_conv_weight = []
    first_conv_bias = []
    # (weight, bias or None) for every conv/Linear layer after the first, in order.
    layers = []
    bn = []

    conv_cnt = 0
    bn_cnt = 0
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv3d, torch.nn.Conv2d)):
            ps = list(m.parameters())
            conv_cnt += 1
            if conv_cnt == 1:
                first_conv_weight.append(ps[0])
                if len(ps) == 2:
                    first_conv_bias.append(ps[1])
            else:
                layers.append((ps[0], ps[1] if len(ps) == 2 else None))
        elif isinstance(m, torch.nn.Linear):
            ps = list(m.parameters())
            layers.append((ps[0], ps[1] if len(ps) == 2 else None))
        elif isinstance(m, torch.nn.BatchNorm3d):
            bn_cnt += 1
            # Partial BN: later BatchNorm3d layers are left out of every group,
            # so the optimizer never updates their affine parameters.
            if not enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif isinstance(m, torch.nn.BatchNorm2d):
            bn.extend(list(m.parameters()))
        elif len(m._modules) == 0:
            if len(list(m.parameters())) > 0:
                raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))

    # Split by layer, not by slicing flat weight/bias lists. The bias list is
    # much shorter (every conv in P3D has bias=False), so slicing it with the
    # weight count put late-layer biases -- e.g. the classifier's -- into slow_bias.
    n_fore = int(len(layers) * slow_rate)
    slow_layers, normal_layers = layers[:n_fore], layers[n_fore:]
    slow_feat = [w for w, _ in slow_layers]
    slow_bias = [b for _, b in slow_layers if b is not None]
    normal_feat = [w for w, _ in normal_layers]
    normal_bias = [b for _, b in normal_layers if b is not None]

    return [
        {'params': first_conv_weight, 'lr_mult': 5 if modality == 'Flow' else 1, 'decay_mult': 1,
         'name': "first_conv_weight"},
        {'params': first_conv_bias, 'lr_mult': 10 if modality == 'Flow' else 2, 'decay_mult': 0,
         'name': "first_conv_bias"},
        {'params': slow_feat, 'lr_mult': 1, 'decay_mult': 1,
         'name': "slow_feat"},
        {'params': slow_bias, 'lr_mult': 2, 'decay_mult': 0,
         'name': "slow_bias"},
        {'params': normal_feat, 'lr_mult': 1, 'decay_mult': 1,
         'name': "normal_feat"},
        {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
         'name': "normal_bias"},
        {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
         'name': "BN scale/shift"},
    ]
if __name__ == '__main__':
    # Smoke test: build P3D199 with the RGB checkpoint (expected in the working
    # directory) and classify one random batch of 10 clips (3 x 16 x 160 x 160).
    # Uses the GPU when available instead of requiring CUDA unconditionally.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = P3D199(pretrained=True, num_classes=400)
    model = model.to(device)
    model.eval()
    # For modality='Flow', change the 2nd dimension from 3 to 2.
    data = torch.rand(10, 3, 16, 160, 160, device=device)
    with torch.no_grad():
        out = model(data)
    print(out.size(), out)