Skip to content

Commit

Permalink
Implement AutoGuideList.quantiles() (#2896)
Browse files Browse the repository at this point in the history
* added quantile method to AutoGuideList

* added AutoGuideList quantile tests

* fixed lint

* updated doc format

Co-authored-by: Vitalii Kleshchevnikov <vk7@sanger.ac.uk>
Co-authored-by: Vitalii Kleshchevnikov <klv2706w@gmail.com>
  • Loading branch information
3 people authored Jul 6, 2021
1 parent 21f6716 commit 7cf026f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,19 @@ def median(self, *args, **kwargs):
result.update(part.median(*args, **kwargs))
return result

def quantiles(self, quantiles, *args, **kwargs):
"""
Returns the posterior quantile values of each latent variable.
:param list quantiles: A list of requested quantiles between 0 and 1.
:returns: A dict mapping sample site name to quantiles tensor.
:rtype: dict
"""
result = {}
for part in self:
result.update(part.quantiles(quantiles, *args, **kwargs))
return result


class AutoCallable(AutoGuide):
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,13 @@ def model():
assert_equal(attr_get(guide_deser), attr_get(guide).data)


def AutoGuideList_x(model):
guide = AutoGuideList(model)
guide.append(AutoNormal(poutine.block(model, expose=["x"])))
guide.append(AutoLowRankMultivariateNormal(poutine.block(model, hide=["x"])))
return guide


@pytest.mark.parametrize(
"auto_class",
[
Expand All @@ -430,6 +437,7 @@ def model():
AutoNormal,
AutoLowRankMultivariateNormal,
AutoLaplaceApproximation,
AutoGuideList_x,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
Expand Down

0 comments on commit 7cf026f

Please sign in to comment.