-
Notifications
You must be signed in to change notification settings - Fork 0
/
q5_1_TRANS.py
86 lines (70 loc) · 2.7 KB
/
q5_1_TRANS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
'''
File for 5.1
Computing the average loss at each timestep over the whole validation set.
'''
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from utils import Experiment, prepare_data, repackage_hidden, ptb_iterator, Batch
if __name__ == '__main__':
    # Parse arguments: the single positional argument is the experiment
    # directory produced by a prior training run.
    parser = argparse.ArgumentParser(description='Compute average loss over validation for a given model.')
    parser.add_argument('exp_dir', type=str,
                        help='The directory from the experiment run for the model you want.')
    args = parser.parse_args()

    # Use the GPU if you have one
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Set up the model
    print('Loading model...')
    exp = Experiment(args.exp_dir)
    model = exp.load_model()
    model = model.to(device)
    # FIX: evaluation must run with dropout/batch-norm in inference mode;
    # the original left the model in training mode, skewing the losses.
    model.eval()
    print(' Model loaded.')

    # The transformer config does not expose batch_size/seq_len on the
    # model object, so hard-code the values used during training.
    if exp.config['model'] == 'TRANSFORMER':
        batch_size = 128
        seq_len = 35
    else:
        batch_size = model.batch_size
        seq_len = model.seq_len
    num_batches = 0

    # Set up the validation data
    _, valid_data, _, _ = prepare_data()

    # reduction='none' keeps a per-example, per-timestep loss so we can
    # average "down the column" of each timestep t=1, t=2, ...
    loss_fn = nn.CrossEntropyLoss(reduction='none')

    # Initialize hidden state (only meaningful for recurrent models;
    # the transformer forward below takes data + mask instead).
    if exp.config['model'] != 'TRANSFORMER':
        hidden = model.init_hidden()
        hidden = hidden.to(device)

    # Loop through the validation set (one epoch), accumulating the mean
    # per-timestep loss of every batch.
    # FIX: wrap in no_grad — no gradients are needed, and the original
    # tracked autograd state for the whole pass (it also called
    # model.zero_grad() twice per step and built an unused `inputs`
    # tensor; both removed).
    mean_losses = torch.zeros(seq_len, device=device)
    with torch.no_grad():
        for step, (x, y) in enumerate(ptb_iterator(valid_data, batch_size, seq_len)):
            print('Step {}'.format(step))
            num_batches += 1
            batch = Batch(torch.from_numpy(x).long().to(device))
            # Model output is (batch, seq, vocab); transpose to
            # (seq, batch, vocab) to match the permute below.
            outputs = model.forward(batch.data, batch.mask).transpose(1, 0)
            targets = torch.from_numpy(y.astype(np.int64)).contiguous().to(device)
            tt = torch.squeeze(targets)
            # Cross-entropy expects (N, C, ...): permute logits to
            # (batch, vocab, seq) against targets (batch, seq).
            loss = loss_fn(outputs.contiguous().permute(1, 2, 0), tt)
            # Mean over the batch dimension, keeping one value per timestep.
            mean_losses += loss.detach().mean(0)

    # Average the per-timestep sums over all batches.
    # FIX: guard against an empty validation iterator (division by zero).
    final_losses = (mean_losses / max(num_batches, 1)).cpu().numpy()

    # Log/save the per-timestep average losses for plotting in 5.1.
    np.save(os.path.join('.', 'avg_5_1.npy'), final_losses)