-
Notifications
You must be signed in to change notification settings - Fork 5
/
hubconf.py
78 lines (67 loc) · 3.04 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
===============================================================================
Author: Anjith George
Institution: Idiap Research Institute, Martigny, Switzerland.
Copyright (C) 2024 Anjith George
This software is distributed under the terms described in the LICENSE file
located in the parent directory of this source code repository.
For inquiries, please contact the author at anjith.george@idiap.ch
===============================================================================
"""
# Packages torch.hub must be able to import before any entrypoint in this
# file is loaded.
dependencies = [
    'torch',
    'torchvision',
    'timm',
]
from backbones import get_model
import torch
def edgeface_base(pretrained=True, **kwargs):
model = get_model('edgeface_base', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_base.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model
def edgeface_xs_gamma_06(pretrained=True, **kwargs):
model = get_model('edgeface_xs_gamma_06', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_xs_gamma_06.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model
def edgeface_xs_q(pretrained=True, **kwargs):
model = get_model('edgeface_xs_q', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_xs_q.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model
def edgeface_xxs(pretrained=True, **kwargs):
model = get_model('edgeface_xxs', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_xxs.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model
def edgeface_xxs_q(pretrained=True, **kwargs):
model = get_model('edgeface_xxs_q', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_xxs_q.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model
def edgeface_s_gamma_05(pretrained=True, **kwargs):
model = get_model('edgeface_s_gamma_05', **kwargs)
if pretrained:
checkpoint_url = 'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/checkpoints/edgeface_s_gamma_05.pt'
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_url, map_location='cpu'
)
model.load_state_dict(state_dict)
return model