Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SIR predict() use of reparameterizers #2431

Merged
merged 2 commits into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions examples/sir_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ def heuristic_init(args, data):

def infer_hmc_cont(model, args, data):
if args.dct:
model = poutine.reparam(model, {"S_aux": DiscreteCosineReparam(),
"I_aux": DiscreteCosineReparam()})
rep = DiscreteCosineReparam()
model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
init_values = heuristic_init(args, data)
return _infer_hmc(args, data, model, init_values=init_values)

Expand Down Expand Up @@ -478,20 +478,24 @@ def predict(args, data, samples, truth=None):
particle_plate = pyro.plate("particles", args.num_samples, dim=-1)

# First we sample discrete auxiliary variables from the continuous
# variables sampled in vectorized_model. Here infer_discrete runs a
# forward-filter backward-sample algorithm. We'll add these new samples to
# the existing dict of samples.
# variables sampled in vectorized_model. This samples only time steps
# [0:duration]. Here infer_discrete runs a forward-filter backward-sample
# algorithm. We'll add these new samples to the existing dict of samples.
model = poutine.condition(continuous_model, samples)
model = particle_plate(model)
if args.dct: # Apply the same reparameterizer as during inference.
rep = DiscreteCosineReparam()
model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
Comment on lines +486 to +488
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the bug: S_aux was being sampled from the prior because the inferred samples only provided S_aux_dct.

model = infer_discrete(model, first_available_dim=-2)
with poutine.trace() as tr:
model(data, args.population)
samples = {name: site["value"]
for name, site in tr.trace.nodes.items()
if site["type"] == "sample"}

# Next we'll run the forward generative process in discrete_model. Again
# we'll update the dict of samples.
# Next we'll run the forward generative process in discrete_model. This
# samples time steps [duration:duration+forecast]. Again we'll update the
# dict of samples.
extended_data = list(data) + [None] * args.forecast
model = poutine.condition(discrete_model, samples)
model = particle_plate(model)
Expand All @@ -501,7 +505,8 @@ def predict(args, data, samples, truth=None):
for name, site in tr.trace.nodes.items()
if site["type"] == "sample"}

# Concatenate sequential time series into tensors.
# Finally we'll concatenate the sequentially sampled values into contiguous
# tensors. This operates on the entire time interval [0:duration+forecast].
for key in ("S", "I", "S2I", "I2R"):
pattern = key + "_[0-9]+"
series = [value
Expand Down
24 changes: 12 additions & 12 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@
'rsa/schelling.py --num-samples=10',
'rsa/schelling_false.py --num-samples=10',
'rsa/semantic_parsing.py --num-samples=10',
'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum',
'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential',
'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 -f 2',
'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --dct',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct',
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the regression test: -f fails before this PR

'smcfilter.py --num-timesteps=3 --num-particles=10',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto',
Expand Down Expand Up @@ -119,10 +119,10 @@
'hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --cuda',
'hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --cuda',
'hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --cuda',
'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum --cuda',
'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential --cuda',
'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --cuda',
'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --dct --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --dct --cuda',
'vae/vae.py --num-epochs=1 --cuda',
'vae/ss_vae_M2.py --num-epochs=1 --cuda',
'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda',
Expand Down Expand Up @@ -160,10 +160,10 @@ def xfail_jit(*args):
'lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit',
'minipyro.py --backend=pyro --jit',
'minipyro.py --jit',
'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum --jit',
'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential --jit',
'sir_hmc.py -w=2 -n=4 -p=10000 --jit',
'sir_hmc.py -w=2 -n=4 -p=10000 --dct --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --dct --jit',
xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'),
'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit',
'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit',
Expand Down