Skip to content

Commit

Permalink
Fix LocScaleReparam, log params in backtest() (#2365)
Browse files Browse the repository at this point in the history
* Fix name bug in LocScaleReparam

* Save scalar params in backtest()
  • Loading branch information
fritzo authored Mar 13, 2020
1 parent fe3fc2a commit 7e20031
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions pyro/contrib/forecast/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,19 @@ def forecaster_options_fn(*args, **kwargs):
"num_samples": num_samples,
"train_walltime": train_walltime,
"test_walltime": test_walltime,
"params": {},
}
results.append(result)
for name, fn in metrics.items():
result[name] = fn(pred, truth)
if isinstance(result[name], (int, float)):
logger.debug("{} = {}".format(name, result[name]))
for name, value in pyro.get_param_store().items():
if value.numel() == 1:
value = value.cpu().item()
result["params"][name] = value
for dct in (result, result["params"]):
for key, value in sorted(dct.items()):
if isinstance(value, (int, float)):
logger.debug("{} = {:0.6g}".format(key, value))

del pred

Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/reparam/loc_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __call__(self, name, fn, obs):
# Apply a partial decentering transform.
params = {key: getattr(fn, key) for key in self.shape_params}
if self.centered is None:
centered = pyro.param("{}_centered",
centered = pyro.param("{}_centered".format(name),
lambda: fn.loc.new_full(event_shape, 0.5),
constraint=constraints.unit_interval)
params["loc"] = fn.loc * centered
Expand Down

0 comments on commit 7e20031

Please sign in to comment.