-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils_model.py
40 lines (32 loc) · 1.24 KB
/
utils_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
"""Based on the logic that only layers that requires grad are modified and thus saved"""
def load_model(model, path, map_location=None):
    """Load a partial checkpoint saved by `save_model` into `model`.

    Prints the checkpoint's parameter count, then merges the checkpoint
    into the model's state dict (only the saved — i.e. trainable — layers
    are overwritten).

    Args:
        model: the torch.nn.Module to update in place.
        path: filesystem path of a checkpoint written by `save_model`.
        map_location: optional device remapping forwarded to `torch.load`
            (e.g. "cpu"). Default None preserves the original behavior.

    Returns:
        The same `model` instance with the checkpoint weights applied.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    _count_parameters(checkpoint)  # log how many params the checkpoint holds
    return _load_partial_state_dict(model, checkpoint)
def save_model(model, path):
    """Save only the trainable parameters of `model` to `path`.

    Relies on the module's convention that only layers with
    requires_grad=True are modified during training, so persisting just
    those keeps the checkpoint small. `load_model` restores such a
    checkpoint.

    Args:
        model: the torch.nn.Module whose trainable weights are saved.
        path: destination filesystem path for `torch.save`.
    """
    _count_trainable_parameters(model)  # log how many params will be saved
    trainable = {k for k, v in model.named_parameters() if v.requires_grad}
    # state_dict() also contains buffers and frozen weights — keep only
    # the trainable subset identified above.
    partial = {k: v for k, v in model.state_dict().items() if k in trainable}
    torch.save(partial, path)
def load_weights(model, path):
    """Read a (possibly partial) checkpoint from `path` and apply it to
    `model`, returning the updated model.

    Unlike `load_model`, this does not print the checkpoint's parameter
    count.
    """
    return _load_partial_state_dict(model, torch.load(path))
def _count_trainable_parameters(model):
total=sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Num trainable params: ", total)
return total
def _count_parameters(state_dict):
total=sum(p.numel() for p in state_dict.values())
print("Num trainable params: ", total)
return total
def _load_partial_state_dict(model, checkpoint):
"""checkpoint has to totally match a subset of the model state dict"""
model_dict = model.state_dict()
model_dict.update(checkpoint)
model.load_state_dict(model_dict)
return model