-
Notifications
You must be signed in to change notification settings - Fork 9
/
xdc.py
42 lines (33 loc) · 2.18 KB
/
xdc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
from model.r2plus1d_18 import r2plus1d_18
# Download URLs for the released pretrained checkpoints, keyed by pretraining
# type. The trailing 8-hex-digit fragment in each filename is the checksum
# prefix verified by `check_hash=True` in `torch.hub.load_state_dict_from_url`.
# NOTE: restored the canonical `github.com` host — the scraped copy had been
# rewritten to the `github.com` mirror, which is not the official release host.
model_urls = {
    "r2plus1d_18_xdc_ig65m_kinetics": "https://github.com/HumamAlwassel/XDC/releases/download/model_weights/r2plus1d_18_xdc_ig65m_kinetics-f24f6ffb.pth",
    "r2plus1d_18_xdc_ig65m_random": "https://github.com/HumamAlwassel/XDC/releases/download/model_weights/r2plus1d_18_xdc_ig65m_random-189d23f4.pth",
    "r2plus1d_18_xdc_audioset": "https://github.com/HumamAlwassel/XDC/releases/download/model_weights/r2plus1d_18_xdc_audioset-f29ffe8f.pth",
    "r2plus1d_18_fs_kinetics": "https://github.com/HumamAlwassel/XDC/releases/download/model_weights/r2plus1d_18_fs_kinetics-622bdad9.pth",
    "r2plus1d_18_fs_imagenet": "https://github.com/HumamAlwassel/XDC/releases/download/model_weights/r2plus1d_18_fs_imagenet-ff446670.pth",
}
def xdc_video_encoder(pretraining='r2plus1d_18_xdc_ig65m_kinetics', progress=False, **kwargs):
    '''Build an R(2+1)D-18 video encoder with XDC pretrained weights.

    Pretrained video encoders as in https://arxiv.org/abs/1911.12667.
    Pretrained weights of all layers except the FC classifier layer are loaded.
    The FC layer (of size 512 x num_classes) is randomly-initialized. Specify
    the keyword argument `num_classes` based on your application (default is 400).

    Args:
        pretraining (string): The model pretraining type to load. Available pretrainings are
            r2plus1d_18_xdc_ig65m_kinetics: XDC pretrained on IG-Kinetics (default)
            r2plus1d_18_xdc_ig65m_random: XDC pretrained on IG-Random
            r2plus1d_18_xdc_audioset: XDC pretrained on AudioSet
            r2plus1d_18_fs_kinetics: fully-supervised Kinetics-pretrained baseline
            r2plus1d_18_fs_imagenet: fully-supervised ImageNet-pretrained baseline
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to the `r2plus1d_18` constructor (e.g. `num_classes`).

    Returns:
        The R(2+1)D-18 model with the requested pretrained weights loaded.

    Raises:
        ValueError: If `pretraining` is not a key of `model_urls`.
    '''
    # Validate with an explicit raise rather than `assert`: asserts are
    # stripped when Python runs with -O, which would let a bad value fall
    # through to an opaque KeyError below.
    if pretraining not in model_urls:
        raise ValueError(
            f'Unrecognized pretraining type {pretraining!r}. '
            f'Available pretrainings: {list(model_urls.keys())}'
        )
    model = r2plus1d_18(pretrained=False, progress=progress, **kwargs)
    state_dict = torch.hub.load_state_dict_from_url(
        model_urls[pretraining], progress=progress, check_hash=True,
    )
    # strict=False is deliberate: the checkpoint's FC classifier weights are
    # allowed to mismatch so the (possibly differently-sized) FC layer stays
    # randomly initialized, as documented above.
    model.load_state_dict(state_dict, strict=False)
    return model