-
Notifications
You must be signed in to change notification settings - Fork 63
/
utils.py
39 lines (32 loc) · 1.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
### Utility functions and classes
class EarlyStopMonitor(object):
    """Track a validation metric and report when it has stopped improving.

    Each call to :meth:`early_stop_check` compares the new value with the best
    value seen so far. An improvement is a *relative* gain larger than
    ``tolerance``; anything else counts as a stale round. Once ``max_round``
    consecutive stale rounds have accumulated, the check returns ``True``.
    """

    def __init__(self, max_round=3, higher_better=True, tolerance=1e-3):
        """Create a monitor.

        Args:
            max_round: Number of consecutive non-improving checks after which
                :meth:`early_stop_check` returns ``True``.
            higher_better: ``True`` if larger metric values are better. When
                ``False`` the metric is negated internally so the same
                "larger is better" comparison applies.
            tolerance: Minimum relative gain over the best value that counts
                as an improvement.
        """
        self.max_round = max_round
        self.num_round = 0  # consecutive checks without an improvement
        self.epoch_count = 0  # number of early_stop_check calls so far
        # Value of epoch_count at the most recent improvement. It stays 0
        # while the very first value is still the best one.
        # NOTE(review): the first check leaves this at 0 although epoch_count
        # is already 1, while later improvements store the 1-based count --
        # confirm which indexing callers expect before relying on it.
        self.best_epoch = 0
        self.last_best = None  # best sign-normalized value seen so far
        self.higher_better = higher_better
        self.tolerance = tolerance

    def early_stop_check(self, curr_val):
        """Record ``curr_val`` and tell whether training should stop.

        Args:
            curr_val: Latest value of the monitored metric.

        Returns:
            ``True`` when the number of consecutive checks without an
            improvement has reached ``max_round``, ``False`` otherwise.
        """
        self.epoch_count += 1

        if not self.higher_better:
            # Negate into a new object: an in-place ``*=`` would modify the
            # caller's value when it is a mutable array type.
            curr_val = -curr_val

        if self.last_best is None:
            # The first observation only establishes the baseline.
            self.last_best = curr_val
        elif self._is_improvement(curr_val):
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1

        return self.num_round >= self.max_round

    def _is_improvement(self, curr_val):
        """Return whether ``curr_val`` beats ``last_best`` by more than ``tolerance``."""
        gain = curr_val - self.last_best
        baseline = np.abs(self.last_best)
        if baseline == 0:
            # A relative gain is undefined against a zero baseline. Any
            # strictly positive gain counts as an improvement, which is what
            # the inf/nan comparison would yield, without dividing by zero.
            return gain > 0
        return gain / baseline > self.tolerance
class RandEdgeSampler(object):
    """Draw random (source, destination) node pairs.

    The candidate pools are the distinct values of the two sequences passed
    to the constructor. Sources and destinations are drawn independently,
    with replacement, using NumPy's global random state.
    """

    def __init__(self, src_list, dst_list):
        """Store the distinct source and destination values.

        Args:
            src_list: Sequence of source node identifiers (duplicates allowed).
            dst_list: Sequence of destination node identifiers (duplicates
                allowed).
        """
        # np.unique returns the sorted distinct values as a 1-D array.
        self.src_list, self.dst_list = np.unique(src_list), np.unique(dst_list)

    def sample(self, size):
        """Return ``size`` randomly drawn sources and destinations.

        Args:
            size: Number of pairs to draw (forwarded to ``np.random.randint``).

        Returns:
            A tuple ``(sources, destinations)`` of arrays picked from the
            stored distinct values.
        """
        n_sources = len(self.src_list)
        n_destinations = len(self.dst_list)
        # Source positions are drawn first, then destination positions, so the
        # sequence of random draws is fixed for a given global seed.
        picked_src = np.random.randint(low=0, high=n_sources, size=size)
        picked_dst = np.random.randint(low=0, high=n_destinations, size=size)
        sources = self.src_list[picked_src]
        destinations = self.dst_list[picked_dst]
        return sources, destinations