-
Notifications
You must be signed in to change notification settings - Fork 909
/
main.py
104 lines (82 loc) · 2.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright © 2023 Apple Inc.
import argparse
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import mnist
class MLP(nn.Module):
    """Multi-layer perceptron built from ``nn.Linear`` layers.

    ``num_layers`` hidden widths of ``hidden_dim`` are placed between
    ``input_dim`` and ``output_dim``, giving ``num_layers + 1`` linear
    layers. ReLU is applied after every layer except the last one.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Widths at every layer boundary: input, each hidden width, output.
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # Consecutive pairs of widths give the (in, out) size of each layer.
        self.layers = [
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(dims, dims[1:])
        ]

    def __call__(self, x):
        """Apply the network to ``x`` and return the last layer's raw output."""
        # The final layer is applied without an activation.
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = nn.relu(layer(x))
        return output_layer(x)
def loss_fn(model, X, y):
    """Return the mean cross-entropy between ``model(X)`` and targets ``y``."""
    logits = model(X)
    return nn.losses.cross_entropy(logits, y, reduction="mean")
def batch_iterate(batch_size, X, y):
    """Yield ``(X_batch, y_batch)`` pairs in a freshly shuffled order.

    A new random permutation of the ``y.size`` sample indices is drawn with
    NumPy each time the generator is started. The permutation is then walked
    in steps of ``batch_size``, so the final batch may be smaller.
    """
    num_samples = y.size
    shuffled = mx.array(np.random.permutation(num_samples))
    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        batch_ids = shuffled[start:stop]
        yield X[batch_ids], y[batch_ids]
def main(args):
    """Train an MLP on the dataset named by ``args.dataset``.

    After every epoch the test accuracy and the elapsed wall-clock time are
    printed. All hyper-parameters are fixed inside this function.
    """
    # Fixed hyper-parameters for this example.
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Seeds NumPy's global RNG, which batch_iterate uses for shuffling.
    np.random.seed(seed)

    # Load the data: args.dataset names a loader function in the mnist module,
    # and each of the four arrays it returns is converted to an mx.array.
    load_dataset = getattr(mnist, args.dataset)
    train_images, train_labels, test_images, test_labels = [
        mx.array(split) for split in load_dataset()
    ]

    # Load the model; the input width is the last dimension of the images.
    input_dim = train_images.shape[-1]
    model = MLP(num_layers, input_dim, hidden_dim, num_classes)
    mx.eval(model.parameters())

    optimizer = optim.SGD(learning_rate=learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    def _train_step(X, y):
        # One optimisation step: compute loss and gradients, update the model.
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        return loss

    def _accuracy(X, y):
        # Mean of the element-wise match between argmax predictions and y.
        predictions = mx.argmax(model(X), axis=1)
        return mx.mean(predictions == y)

    # Compile both functions, declaring the model state as captured
    # input (and, for the training step, output).
    step = mx.compile(_train_step, inputs=model.state, outputs=model.state)
    eval_fn = mx.compile(_accuracy, inputs=model.state)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            step(X, y)
            # Evaluate the updated state after every batch.
            mx.eval(model.state)
        accuracy = eval_fn(test_images, test_labels)
        toc = time.perf_counter()
        report = (
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
        print(report)
if __name__ == "__main__":
    # Command-line interface: an optional GPU switch and the dataset choice.
    parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
    parser.add_argument("--gpu", help="Use the Metal back-end.", action="store_true")
    parser.add_argument(
        "--dataset",
        help="The dataset to use.",
        choices=["mnist", "fashion_mnist"],
        default="mnist",
        type=str,
    )
    args = parser.parse_args()

    # Without --gpu, make the CPU the default device before training starts.
    use_gpu = args.gpu
    if not use_gpu:
        mx.set_default_device(mx.cpu)

    main(args)