04_mnist_managed_api.py
from pathlib import Path
from time import time
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from clu.metrics import Accuracy, Average
import ciclo
from ciclo import managed
print(jax.local_devices())
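
# pick an execution strategy for the managed API; "jit" runs each step as a
# single jit-compiled function. lift_batch_size adapts the batch size to the
# strategy (data-parallel strategies may scale it by the device count).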
strategy = ciclo.get_strategy("jit")
batch_size = strategy.lift_batch_size(32)
# load the MNIST dataset
ds_train: tf.data.Dataset = tfds.load("mnist", split="train", shuffle_files=True)
ds_train = ds_train.repeat().shuffle(1024).batch(batch_size).prefetch(1)
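# validation pipeline: drop_remainder keeps batch shapes static for compilation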
ds_valid: tf.data.Dataset = tfds.load("mnist", split="test")
ds_valid = ds_valid.batch(batch_size, drop_remainder=True).prefetch(1)
# Define model
class Linear(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x / 255.0
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=10)(x)
        return x
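
# running average of the scalar "loss" output, tracked as a CLU metric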
AverageLoss = Average.from_output("loss")
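# train state for the managed API; the metric fields declared below are
# accumulated by the managed steps and cleared explicitly with reset_metrics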
class ManagedState(managed.ManagedState):
    accuracy: Accuracy
    loss: AverageLoss
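
# compute the loss and log metrics; managed.train_step differentiates the value
# registered via logs.add_loss to produce the parameter update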
def loss_fn(state: ManagedState, batch):
    inputs, labels = batch["image"], batch["label"]
    logits = state.apply_fn({"params": state.params}, inputs)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    logs = ciclo.logs()
    logs.add_loss("loss", loss)
    logs.add_stateful_metric(
        "accuracy", Accuracy.from_model_output(logits=logits, labels=labels)
    )
    logs.add_stateful_metric("loss", AverageLoss.from_model_output(loss=loss))
    return logs, state
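
# managed.train_step applies gradient updates; managed.step only runs the
# forward pass and metric updates (used here for evaluation)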
train_step = managed.train_step(loss_fn)
eval_step = managed.step(loss_fn)
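
# clear the accumulated metrics so each evaluation/reporting window starts fresh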
@managed.step
def reset_metrics(state: ManagedState):
    return state.replace(
        accuracy=Accuracy.empty(),
        loss=AverageLoss.empty(),
    )
# Initialize state
model = Linear()
variables = model.init(jax.random.PRNGKey(0), jnp.empty((1, 28, 28, 1)))
state = ManagedState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adamw(1e-3),
    accuracy=Accuracy.empty(),
    loss=AverageLoss.empty(),
    strategy=strategy,
)
# training loop
total_samples = 32 * 100
total_steps = total_samples // batch_size
eval_steps = total_steps // 10
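
# schedule: run train_step on every batch; every eval_steps steps reset the
# metrics, run a validation inner loop (its logs get the "_valid" suffix),
# checkpoint on the best validation accuracy and apply early stopping; a Keras
# style progress bar reports progress. The loop stops after total_samples samples.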
state, history, *_ = ciclo.loop(
    state,
    ds_train.as_numpy_iterator(),
    {
        **train_step,
        ciclo.after(1).every(eval_steps): [
            reset_metrics,
            ciclo.inner_loop(
                "valid",
                lambda state: ciclo.loop(
                    state,
                    ds_valid.as_numpy_iterator(),
                    {ciclo.every(steps=1): eval_step},
                ),
            ),
            ciclo.checkpoint(
                f"logdir/{Path(__file__).stem}/{int(time())}",
                monitor="accuracy_valid",
                mode="max",
                keep=3,
            ),
            ciclo.early_stopping(
                monitor="accuracy_valid",
                mode="max",
                patience=eval_steps * 2,
            ),
            reset_metrics,
        ],
        **ciclo.keras_bar(total=total_steps),
    },
    stop=ciclo.at(samples=total_samples),
)
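
# Example of consuming the returned history (a sketch: it assumes ciclo's
# History.collect API, as used in the library's basic examples).
steps, loss = history.collect("steps", "loss")
print(f"logged {len(steps)} steps, final training loss: {loss[-1]:.4f}")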