This repository has been archived by the owner on Jul 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
base_model.py
61 lines (47 loc) · 1.71 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
BasicSR/codes/models/modules/LPIPS/base_model.py (6-Dez-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/LPIPS/base_model.py
"""
import os
import torch
from torch.autograd import Variable
#from pdb import set_trace as st
from IPython import embed
class BaseModel:
    """Skeleton base class for models, providing no-op hooks and I/O helpers.

    Most methods are placeholders that subclasses are expected to override.
    Several helpers read attributes that this class never assigns
    (``self.input``, ``self.image_paths``, ``self.save_dir``); presumably a
    subclass sets them before those helpers are called -- verify against the
    concrete subclasses.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the human-readable name of this model class."""
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Store the device configuration on the instance.

        Args:
            use_gpu: Stored as ``self.use_gpu``.
            gpu_ids: Stored as ``self.gpu_ids``. ``None`` (the default) means
                ``[0]``; a fresh list is created per call so instances never
                share one mutable default list.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        """Placeholder hook; does nothing in the base class."""
        pass

    def optimize_parameters(self):
        """Placeholder hook; does nothing in the base class."""
        pass

    def get_current_visuals(self):
        """Return ``self.input`` (not assigned anywhere in this class)."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict; the base class tracks no errors."""
        return {}

    def save(self, label):
        """Placeholder hook; does nothing in the base class."""
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Write ``network.state_dict()`` to ``<path>/<epoch>_net_<label>.pth``.

        Args:
            network: Object exposing ``state_dict()``.
            path: Directory in which the checkpoint file is written.
            network_label: Label used in the file name.
            epoch_label: Epoch identifier used as the file name prefix.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<save_dir>/<epoch>_net_<label>.pth`` into ``network``.

        Reads ``self.save_dir``, which this class never assigns.

        Args:
            network: Object exposing ``load_state_dict()``.
            network_label: Label used in the file name.
            epoch_label: Epoch identifier used as the file name prefix.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """Placeholder hook; does nothing in the base class."""
        pass

    def get_image_paths(self):
        """Return ``self.image_paths`` (not assigned anywhere in this class)."""
        return self.image_paths

    def save_done(self, flag=False):
        """Write ``flag`` as a completion marker under ``self.save_dir``.

        Two files are produced: ``done_flag.npy`` (``np.save`` appends the
        extension) and a text file ``done_flag`` holding the flag formatted
        with ``%i``.

        Args:
            flag: Value to record.
        """
        # Imported locally because the module header does not import numpy.
        import numpy as np

        flag_path = os.path.join(self.save_dir, 'done_flag')
        np.save(flag_path, flag)
        np.savetxt(flag_path, [flag, ], fmt='%i')