-
Notifications
You must be signed in to change notification settings - Fork 6
/
misc.py
66 lines (58 loc) · 1.83 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
# Lookup table from dtype name strings to the torch dtype objects of the
# same name. Several names are aliases for one dtype (e.g. 'float' and
# 'float32'); every alias is kept as its own key.
torch_dtypes = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in (
        'float', 'float32',
        'float64', 'double',
        'float16', 'half',
        'uint8',
        'int8',
        'int16', 'short',
        'int32', 'int',
        'int64', 'long',
    )
}
def onehot(indexes, N=None, ignore_index=None):
    """Create a one-hot representation of ``indexes``.

    Args:
        indexes: long tensor of class indexes, of any shape.
        N: number of possible entries (size of the new trailing dimension).
            When omitted it is inferred as ``indexes.max() + 1``.
        ignore_index: optional index value whose positions are left as
            all-zero rows in the result. Any value is accepted, including
            negative padding markers such as ``-1``.

    Returns:
        A ``torch.uint8`` tensor of shape ``indexes.shape + (N,)`` on the
        same device as ``indexes``.
    """
    # Use a plain Python int for the size; ``indexes.max()`` is a 0-dim tensor.
    N = int(indexes.max()) + 1 if N is None else int(N)
    output = torch.zeros(*indexes.shape, N,
                         dtype=torch.uint8, device=indexes.device)

    if ignore_index is None:
        output.scatter_(-1, indexes.unsqueeze(-1), 1)
        return output

    # Ignored positions may hold values that are not valid scatter indexes
    # (e.g. negative padding), so point them at slot 0 for the scatter and
    # clear their rows afterwards.
    ignored = indexes.eq(ignore_index)
    safe_indexes = indexes.masked_fill(ignored, 0)
    output.scatter_(-1, safe_indexes.unsqueeze(-1), 1)
    output.masked_fill_(ignored.unsqueeze(-1), 0)
    return output
def set_global_seeds(i):
    """Seed all random number generators used here with the value ``i``.

    PyTorch is seeded only when it can be imported: its default generator
    always, and every CUDA device as well when CUDA is available. NumPy's
    global generator and Python's ``random`` module are then seeded
    unconditionally.
    """
    try:
        import torch as torch_lib
    except ImportError:
        # Best effort: without PyTorch only NumPy and ``random`` are seeded.
        torch_lib = None

    if torch_lib is not None:
        torch_lib.manual_seed(i)
        cuda = torch_lib.cuda
        if cuda.is_available():
            cuda.manual_seed_all(i)

    np.random.seed(i)
    random.seed(i)
class CheckpointModule(nn.Module):
    """Run a wrapped module under activation checkpointing.

    With ``num_segments > 1`` the wrapped ``nn.Sequential`` is handed to
    ``checkpoint_sequential`` with that many segments; otherwise the whole
    module is handed to ``checkpoint`` as a single unit.

    Args:
        module: the module to wrap. Must be an ``nn.Sequential`` unless
            ``num_segments`` is 1.
        num_segments: number of segments passed to ``checkpoint_sequential``.
        use_reentrant: forwarded to the checkpoint functions when the
            installed PyTorch accepts that keyword. Defaults to ``True``,
            the historical default of ``torch.utils.checkpoint``, so
            existing callers keep the same checkpointing variant.
    """

    def __init__(self, module, num_segments=1, use_reentrant=True):
        super(CheckpointModule, self).__init__()
        assert num_segments == 1 or isinstance(module, nn.Sequential), \
            "num_segments != 1 requires module to be an nn.Sequential"
        self.module = module
        self.num_segments = num_segments
        self.use_reentrant = use_reentrant

    def forward(self, x):
        """Apply the wrapped module to ``x`` through a checkpoint call."""
        if self.num_segments > 1:
            extra = self._reentrant_kwargs(checkpoint_sequential)
            return checkpoint_sequential(
                self.module, self.num_segments, x, **extra)
        extra = self._reentrant_kwargs(checkpoint)
        return checkpoint(self.module, x, **extra)

    def _reentrant_kwargs(self, fn):
        """Return ``{'use_reentrant': ...}`` if ``fn`` accepts it, else ``{}``."""
        import inspect

        # Recent PyTorch asks for ``use_reentrant`` to be passed explicitly,
        # while older releases reject unknown keywords; only forward it when
        # the installed function declares the parameter.
        try:
            supported = 'use_reentrant' in inspect.signature(fn).parameters
        except (TypeError, ValueError):
            supported = False
        if not supported:
            return {}
        # getattr keeps instances created before this attribute existed
        # (e.g. unpickled ones) working with the historical default.
        return {'use_reentrant': getattr(self, 'use_reentrant', True)}