-
Notifications
You must be signed in to change notification settings - Fork 2
/
logger.py
68 lines (58 loc) · 2.22 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import matplotlib.pyplot as plt
import numpy as np
class Logger(object):
    """Record training metrics to a tab-separated log file, with a simple plot helper.

    File layout: one header line of column names, then one line per
    ``append`` call. Every cell (header or value) is followed by a tab and
    each line ends with a newline. The in-memory history lives in
    ``self.numbers`` (column name -> list of appended values), which
    ``plot`` draws.

    Passing ``fpath=None`` keeps the history in memory only and writes
    nothing. The logger can be used as a context manager; the file is closed
    on exit.
    """

    def __init__(self, fpath, title=None, resume=False):
        """Open a new log or resume an existing one.

        Args:
            fpath: Path of the log file, or None for an in-memory-only logger.
            title: Prefix used in plot legend labels; None means "".
            resume: If True, read the existing header and rows of ``fpath``
                into memory, then open the file for appending instead of
                truncating it.
        """
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        # Start from a consistent empty state so methods called before
        # set_names() do not hit missing attributes.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load_history(fpath)
            self.file = open(fpath, "a")
        else:
            self.file = open(fpath, "w")

    def _load_history(self, fpath):
        """Load the header and previously logged rows of ``fpath`` into memory."""
        with open(fpath, "r") as log_file:
            header = log_file.readline().rstrip()
            # An empty file has no header yet; set_names() will then write one.
            self.names = header.split("\t") if header else []
            self.numbers = {name: [] for name in self.names}
            for line in log_file:
                text = line.rstrip()
                if not text:
                    continue  # tolerate blank lines, e.g. an extra newline at EOF
                # zip() ignores surplus cells instead of raising IndexError.
                for name, cell in zip(self.names, text.split("\t")):
                    self.numbers[name].append(self._parse_value(cell))

    @staticmethod
    def _parse_value(cell):
        """Convert a logged cell back to float when possible, else keep the text."""
        try:
            return float(cell)
        except ValueError:
            return cell

    @staticmethod
    def _format_value(value):
        """Render one value as a log cell; floats get six decimal places."""
        if isinstance(value, float):
            return "{0:.6f}".format(value)
        return str(value)

    def _write_row(self, cells):
        """Write one row of already-formatted cells and flush; no-op without a file."""
        if self.file is None:
            return
        # Same layout as before: each cell followed by a tab, then a newline.
        self.file.write("".join(cell + "\t" for cell in cells) + "\n")
        # Flush per row so the on-disk log stays current during training.
        self.file.flush()

    def set_names(self, names):
        """Declare the column names and write them as the header line.

        On a resumed log that already has a header, the loaded history is kept
        and no second header is written; ``names`` must then match the loaded
        columns.

        Args:
            names: Iterable of column-name strings.

        Raises:
            ValueError: if resuming and ``names`` differs from the header
                already present in the file.
        """
        names = list(names)
        if self.resume and self.names:
            if names != self.names:
                raise ValueError(
                    "Names {} do not match resumed log header {}".format(
                        names, self.names))
            return
        self.names = names
        self.numbers = {name: [] for name in names}
        self._write_row(names)

    def append(self, numbers):
        """Log one row of values, given in the same order as ``self.names``.

        Args:
            numbers: Sequence of values; floats are written with six decimals,
                anything else via ``str``.

        Raises:
            ValueError: if the number of values differs from the number of names.
        """
        if len(numbers) != len(self.names):
            raise ValueError("Numbers do not match names")
        self._write_row([self._format_value(num) for num in numbers])
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        """Plot the history of the given columns (default: all) against row index."""
        names = self.names if names is None else names
        for name in names:
            values = np.asarray(self.numbers[name])
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if one was opened."""
        if self.file is not None:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception from the with-body propagate.
        self.close()