-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
73 lines (57 loc) · 2.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import matplotlib.pyplot as plt
import numpyro
from numpyro.infer import NUTS, MCMC, init_to_median
from jax import random
import arviz as az
numpyro.set_host_device_count(4)
from lqg import tracking
from lqg.infer.models import lifted_model
from lqg.infer.utils import sample_from_prior
def parse_args(argv=None):
    """Parse command-line options for a single parameter-recovery run.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the
            default, matching the original behaviour) argparse reads
            ``sys.argv[1:]``; passing an explicit list makes the parser usable
            from tests or notebooks.

    Returns:
        argparse.Namespace with attributes ``ntrial``, ``seed``, ``time``,
        ``nsamp``, ``nwarmup``, ``nchain``, ``model``, ``plot`` and ``save``.
    """
    parser = argparse.ArgumentParser(description="Coverage runs")
    parser.add_argument("--ntrial", type=int, default=20,
                        help="Number of trials.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for the simulation")
    parser.add_argument("--time", type=int, default=720,
                        help="Time steps per trial")
    parser.add_argument("--nsamp", type=int, default=5_000,
                        help="Number of samples drawn by NUTS")
    parser.add_argument("--nwarmup", type=int, default=2_500,
                        help="Number of burn-in samples.")
    parser.add_argument("--nchain", type=int, default=4,
                        help="Number of chains.")
    parser.add_argument("--model", type=str, default="BoundedActor",
                        help="Model type (lqg.tracking)")
    # BooleanOptionalAction (Python >= 3.9) provides both --plot and --no-plot.
    # An explicit default=False keeps the attribute a real bool instead of None
    # when neither form is given.
    parser.add_argument('--plot', action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Show trajectory and posterior pair plots.")
    parser.add_argument('--save', action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Write the posterior summary to a CSV file.")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Local import: only needed to make sure the results directory exists.
    from pathlib import Path

    args = parse_args()

    # Resolve the model class by name from lqg.tracking (e.g. "BoundedActor").
    Model = getattr(tracking, args.model)
    # Ground-truth parameters (a name -> value mapping: it is splatted into the
    # model constructor and indexed by key below). Presumably drawn from the
    # model's prior -- see lqg.infer.utils.sample_from_prior to confirm.
    params = sample_from_prior(Model, args.seed)

    # Derive independent keys for data simulation and for MCMC. Passing the
    # same PRNGKey to both would reuse one random stream for the simulated data
    # and the sampler, which JAX explicitly warns against.
    sim_key, mcmc_key = random.split(random.PRNGKey(args.seed))

    # setup model and simulate data
    model = Model(T=args.time, **params)
    x = model.simulate(sim_key, n=args.ntrial)

    if args.plot:
        # visualize trajectories of the first trial; assumes x is indexed as
        # (trial, time, channel) with two plotted channels -- TODO confirm
        plt.plot(x[0, :, 0])
        plt.plot(x[0, :, 1])
        plt.xlabel("time")
        plt.ylabel("position")
        plt.show()

    # NUTS over lifted_model, with initial values chosen by init_to_median.
    nuts_kernel = NUTS(lifted_model, init_strategy=init_to_median)
    # NOTE(review): the file calls numpyro.set_host_device_count(4) at import
    # time; with --nchain larger than the host device count NumPyro may run
    # chains sequentially instead of in parallel -- confirm if that matters.
    mcmc = MCMC(nuts_kernel, num_warmup=args.nwarmup, num_samples=args.nsamp,
                num_chains=args.nchain)
    # Extra positional arguments (x, Model) are forwarded to lifted_model.
    mcmc.run(mcmc_key, x, Model)
    idata = az.convert_to_inference_data(mcmc)

    if args.plot:
        # Posterior pair plot with the ground-truth values marked.
        az.plot_pair(idata, reference_values=params, figsize=(6, 6), kind="hexbin")
        plt.show()

    if args.save:
        summary = az.summary(idata)
        for key in params:
            # Store the true value in the parameter's own row ("true" column)
            # and also as a constant column, so every row carries the full set
            # of ground-truth parameters for later aggregation.
            summary.loc[key, "true"] = params[key]
            summary[key] = params[key]
        summary["seed"] = args.seed

        # Create the output directory up front; otherwise a missing directory
        # would raise only after the (long) MCMC run, losing its results.
        out_dir = Path("results/parameter-recovery")
        out_dir.mkdir(parents=True, exist_ok=True)
        summary.to_csv(out_dir / f"{args.model}-{args.seed}.csv")