-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_train.py
77 lines (67 loc) · 3.72 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os
from pprint import pprint
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Accepts real bools unchanged; otherwise matches common yes/no
    spellings case-insensitively. Raises argparse.ArgumentTypeError
    for anything unrecognized, so argparse reports a clean CLI error.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
if __name__ == "__main__":
    # Entry point: parse CLI arguments, pin the CUDA device, then dispatch
    # to the selected training framework in the project-local `lru` package.
    parser = argparse.ArgumentParser()

    # --- Logging / experiment bookkeeping ---
    parser.add_argument("--use_wandb", type=str2bool, default=False, help="log with wandb?")
    parser.add_argument("--wandb_project", type=str, default="minimal-LRU", help="wandb project name")
    parser.add_argument("--wandb_entity", type=str, help="wandb entity name, e.g. username")
    parser.add_argument("--dir_name", type=str, default="./cache_dir", help="name of directory where data is cached")
    parser.add_argument("--dataset", type=str, default="mnist-classification")
    parser.add_argument("--jax_seed", type=int, default=1919, help="seed randomness")

    # --- Model parameters ---
    parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in the network")
    parser.add_argument("--d_model", type=int, default=256, help="Number of features")
    parser.add_argument("--d_hidden", type=int, default=128, help="Latent size of recurrent unit")
    parser.add_argument(
        "--pooling",
        type=str,
        default="mean",
        choices=["mean", "last", "none"],
        # NOTE: original help text embedded literal backslashes ("\\");
        # plain adjacent-string concatenation is what was intended.
        help="(for classification tasks) "
        "mean: mean pooling, "
        "last: take last element, "
        "none: no pooling",
    )
    parser.add_argument("--r_min", type=float, default=0.0, help="|lambda|_min for LRU")
    parser.add_argument("--r_max", type=float, default=1.0, help="|lambda|_max for LRU")
    parser.add_argument("--norm", type=str, default="none", choices=['batchnorm', 'layernorm', 'rmsnorm', 'none'],
                        help="Type of normalization")

    # --- Optimization parameters ---
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--epochs", type=int, default=100, help="Max number of epochs")
    parser.add_argument("--early_stop_patience", type=int, default=1000,
                        help="Number of epochs to continue training when val loss plateaus")
    parser.add_argument("--lr_base", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("--lr_min", type=float, default=1e-7, help="Minimum learning rate")
    parser.add_argument("--lr_factor", type=float, default=0.5, help="ssm lr = lr_factor * lr_base")
    parser.add_argument("--cosine_anneal", type=str2bool, default=True, help="Use cosine annealing")
    parser.add_argument("--warmup_end", type=int, default=0, help="When to end linear warmup")
    parser.add_argument("--lr_patience", type=int, default=1000000, help="Patience before decaying lr")
    parser.add_argument("--reduce_factor", type=float, default=1.0, help="Factor to decay lr")
    parser.add_argument("--p_dropout", type=float, default=0.1, help="Probability of dropout")
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay value")
    # BUG FIX: the original default 'bst' was not in `choices`. argparse does
    # not validate defaults against `choices`, so running without --framework
    # fell straight through to the ValueError below. Default to 'online'
    # (one of the handled frameworks) — TODO confirm intended default.
    parser.add_argument("--framework", type=str, default='online', choices=['online', 'flax', 'offline'],
                        help="Training framework")
    parser.add_argument("--device", type=str, default='0', choices=['0', '1', '2', '3'], help="CUDA device")

    args = parser.parse_args()
    pprint(args)

    # Restrict visible GPUs BEFORE importing `lru` (and hence JAX), which
    # reads CUDA_VISIBLE_DEVICES at import/initialization time.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    import lru

    # Dispatch to the chosen training loop.
    if args.framework == 'flax':
        lru.flax_train(args)
    elif args.framework == 'online':
        lru.online_train(args)
    elif args.framework == 'offline':
        lru.offline_train(args)
    else:
        # Unreachable now that the default is a valid choice, but kept as a
        # defensive guard.
        raise ValueError("Invalid model")