-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
81 lines (55 loc) · 1.94 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import jax.numpy as jnp
from jax import pmap, host_id, jit
from jax.tree_util import tree_map
from jax.nn import one_hot, log_softmax
import datetime
import os
import pickle
def shard(x):
    """Run an identity function under ``pmap`` so ``x`` is mapped over its first axis.

    Args:
        x: value handed straight to the pmapped identity function; its leading
            axis is the one pmap maps over.

    Returns:
        The output of the pmapped identity applied to ``x``.
    """
    def _identity(value):
        return value

    pmapped_identity = pmap(_identity)
    return pmapped_identity(x)
def replicate(x, replicas=8):
    """Stack every leaf of ``x`` ``replicas`` times, then shard the result.

    Args:
        x: pytree whose leaves are to be copied.
        replicas: number of copies stacked along a new leading axis of each
            leaf (defaults to 8).

    Returns:
        The result of ``shard`` applied to the pytree of stacked leaves.
    """
    def _stack_copies(leaf):
        # new leading axis of length `replicas`, every slice being `leaf`
        return jnp.stack([leaf for _ in range(replicas)])

    return shard(tree_map(_stack_copies, x))
def shapes_of(pytree):
    """Summarise a pytree by replacing each leaf with ``(shape, type, dtype)``.

    Args:
        pytree: pytree whose leaves expose ``.shape`` and ``.dtype``.

    Returns:
        A pytree with the same structure, each leaf swapped for a 3-tuple of
        the leaf's shape, its Python type and its dtype.
    """
    def _describe(leaf):
        return leaf.shape, type(leaf), leaf.dtype

    return tree_map(_describe, pytree)
def DTS():
    """Return the current local date-time as a ``YYYYMMDD_HHMMSS`` string."""
    now = datetime.datetime.now()
    return now.strftime('%Y%m%d_%H%M%S')
def ensure_dir_exists(directory):
    """Create ``directory`` (and any missing parents) if it is not already there.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race: a separate
    # os.path.exists() test can pass in two processes at once, after which
    # the slower os.makedirs() call fails with FileExistsError.
    os.makedirs(directory, exist_ok=True)
def ensure_dir_exists_for_file(fname):
    """Make sure the directory that will contain ``fname`` exists.

    Args:
        fname: path of a file about to be written.
    """
    directory = os.path.dirname(fname)
    # A bare filename has an empty dirname, meaning the current directory,
    # which already exists; os.makedirs('') would raise FileNotFoundError.
    if directory:
        ensure_dir_exists(directory)
def primary_host():
    """Return True when this process is the one with index 0.

    Returns:
        bool: whether the JAX process index of the current process equals 0.
    """
    import jax

    # jax.host_id is the deprecated predecessor of jax.process_index; prefer
    # the current API and fall back to the old name only when the installed
    # jax does not provide process_index.
    index_fn = getattr(jax, "process_index", None) or host_id
    return index_fn() == 0
def softmax_cross_entropy(logits, labels):
    """Per-example softmax cross entropy between ``logits`` and integer ``labels``.

    Args:
        logits: array of unnormalised scores; the last axis indexes classes.
        labels: class indices, one-hot encoded to ``logits.shape[-1]`` classes.

    Returns:
        Array of losses with the last (class) axis of ``logits`` summed out.
    """
    num_classes = logits.shape[-1]
    log_probs = log_softmax(logits)
    target_mask = one_hot(labels, num_classes)
    # only the log-probability at the labelled class survives the mask
    masked_log_probs = log_probs * target_mask
    return -jnp.sum(masked_log_probs, axis=-1)
def accuracy_mean_loss(calc_logits_fn, dataset):
    """Evaluate accuracy and mean cross-entropy loss over ``dataset``.

    Args:
        calc_logits_fn: callable mapping a batch of inputs to logits.
        dataset: iterable yielding ``(x, y_true)`` batches.

    Returns:
        Tuple ``(accuracy, mean_loss)`` of Python floats.

    Raises:
        ZeroDivisionError: if ``dataset`` yields no batches.
    """
    def _evaluate_batch(inputs, targets):
        batch_logits = calc_logits_fn(inputs)
        per_example_loss = softmax_cross_entropy(batch_logits, targets)
        return jnp.argmax(batch_logits, axis=-1), per_example_loss

    evaluate_batch = jit(_evaluate_batch)

    hits = 0
    loss_sum = 0
    seen = 0
    for inputs, targets in dataset:
        predictions, per_example_loss = evaluate_batch(inputs, targets)
        hits += jnp.sum(targets == predictions)
        loss_sum += jnp.sum(per_example_loss)
        # counts the leading axis only; assumes one label per example along
        # axis 0 -- TODO confirm against callers
        seen += len(targets)

    return float(hits / seen), float(loss_sum / seen)
def save_params(run, epoch, params):
    """Pickle ``params`` to ``params/<run>/<epoch>.pkl``, creating directories as needed.

    Args:
        run: run identifier used as the sub-directory name.
        epoch: epoch identifier used as the file stem.
        params: object to serialise with ``pickle``.
    """
    target_path = f"params/{run}/{epoch}.pkl"
    ensure_dir_exists_for_file(target_path)
    with open(target_path, 'wb') as sink:
        pickle.dump(params, sink)
def load_params(fname):
    """Load and return the pickled object stored in ``fname``.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    call this on files from a trusted source.

    Args:
        fname: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(fname, 'rb') as source:
        return pickle.load(source)