Support torch.jit in ELBO implementations (#1109)
* First attempt at jitting loss_and_grads

* Add grad()-based implementation of jit_loss_and_grads

* Fix perf bug in jit_loss_and_grads; add xfailing tests

* Fix memory leak in jit_loss_and_grads

* Support jit in TraceGraph_ELBO

* Add xfailing tests for jit_loss_and_grads value

* Implement Trace_ELBO.jit_loss_and_grads

* Refactor to create JitTrace*_ELBO classes
fritzo authored and eb8680 committed May 5, 2018
1 parent bf90bf8 commit 0a138ff
Showing 9 changed files with 505 additions and 24 deletions.
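
All of the commits listed above follow one pattern: loss_and_grads lazily builds a closure, compiles it with the legacy torch.jit.compile(nderivs=1) decorator used throughout this diff, caches the compiled closure on self, and holds self only through a weakref so the cache does not create a reference cycle (the memory-leak fix mentioned above). A minimal toy sketch of that pattern, not Pyro code; the class name and the quadratic stand-in loss are invented for illustration:

import weakref

import torch


class CachedJitLoss(object):
    """Toy stand-in for the JIT ELBO pattern: compile the loss closure once,
    cache it on self, and reuse it on every later call."""

    def __init__(self, num_particles=1):
        self.num_particles = num_particles

    def loss_and_grads(self, *args):
        if getattr(self, '_compiled', None) is None:
            # Hold self only weakly: the closure is stored on self below, and a
            # strong reference captured here would create a cycle.
            weakself = weakref.ref(self)

            @torch.jit.compile(nderivs=1)  # legacy decorator from the PyTorch of this commit
            def compiled_loss(args_list):
                self = weakself()
                # stand-in for the real (surrogate) ELBO computation
                return sum((a * a).sum() for a in args_list) / self.num_particles

            self._compiled = compiled_loss

        loss = self._compiled(list(args))
        loss.backward()  # the first backward pass triggers jit compilation
        return loss.item()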
4 changes: 4 additions & 0 deletions docs/source/inference_algos.rst
@@ -13,21 +13,25 @@ ELBO
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.infer.trace_elbo
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.infer.tracegraph_elbo
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.infer.traceenum_elbo
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

Importance
----------
17 changes: 10 additions & 7 deletions pyro/infer/__init__.py
@@ -5,22 +5,25 @@
from pyro.infer.enum import config_enumerate
from pyro.infer.importance import Importance
from pyro.infer.svi import SVI
from pyro.infer.trace_elbo import Trace_ELBO
from pyro.infer.traceenum_elbo import TraceEnum_ELBO
from pyro.infer.tracegraph_elbo import TraceGraph_ELBO
from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
from pyro.infer.util import enable_validation, is_validation_enabled

__all__ = [
    "config_enumerate",
    "enable_validation",
    "is_validation_enabled",
    "ELBO",
    "Importance",
    "EmpiricalMarginal",
    "TracePredictive",
    "Importance",
    "JitTraceEnum_ELBO",
    "JitTraceGraph_ELBO",
    "JitTrace_ELBO",
    "SVI",
    "TracePosterior",
    "Trace_ELBO",
    "TraceEnum_ELBO",
    "TraceGraph_ELBO",
    "TracePosterior",
    "TracePredictive",
    "Trace_ELBO",
]
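
With this change the JIT variants are part of the public pyro.infer namespace, so downstream code imports them exactly like the existing ELBO classes; for example:

from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO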
82 changes: 82 additions & 0 deletions pyro/infer/trace_elbo.py
@@ -1,7 +1,11 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
@@ -142,3 +146,81 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss


class JitTrace_ELBO(Trace_ELBO):
    """
    Like :class:`Trace_ELBO` but uses :func:`torch.jit.compile` to compile
    :meth:`loss_and_grads`.

    This works only for a limited set of models:

    - Models must have static structure.
    - Models must not depend on any global data (except the param store).
    - All model inputs that are tensors must be passed in via ``*args``.
    - All model inputs that are *not* tensors must be passed in via
      ``**kwargs``, and these will be fixed to their values on the first
      call to :meth:`jit_loss_and_grads`.

    .. warning:: Experimental. Interface subject to change.
    """
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_loss_and_surrogate_loss', None) is None:
            # populate param store
            with poutine.block():
                with poutine.trace(param_only=True) as param_capture:
                    for _ in self._get_traces(model, guide, *args, **kwargs):
                        pass
            self._param_names = list(param_capture.trace.nodes.keys())

            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @torch.jit.compile(nderivs=1)
            def loss_and_surrogate_loss(args_list, param_list):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
                    elbo_particle = 0
                    surrogate_elbo_particle = 0
                    log_r = None

                    # compute elbo and surrogate elbo
                    for name, site in model_trace.nodes.items():
                        if site["type"] == "sample":
                            elbo_particle = elbo_particle + site["log_prob_sum"]
                            surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

                    for name, site in guide_trace.nodes.items():
                        if site["type"] == "sample":
                            log_prob, score_function_term, entropy_term = site["score_parts"]

                            elbo_particle = elbo_particle - site["log_prob_sum"]

                            if not is_identically_zero(entropy_term):
                                surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                            if not is_identically_zero(score_function_term):
                                if log_r is None:
                                    log_r = _compute_log_r(model_trace, guide_trace)
                                site = log_r.sum_to(site["cond_indep_stack"])
                                surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

                    loss = loss - elbo_particle / self.num_particles
                    surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        # invoke _loss_and_surrogate_loss
        args_list = list(args)
        param_list = [pyro.param(name).unconstrained() for name in self._param_names]
        loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list)
        surrogate_loss.backward()  # this line triggers jit compilation
        loss = loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
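
A minimal usage sketch for the new class (not part of the diff; the model, guide, and data below are hypothetical). It illustrates the docstring's constraint that tensor inputs flow in positionally through *args; whether a particular model actually compiles depends on the jit's current op coverage (see the xfailing tests below):

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam


def model(data):
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("obs", dist.Normal(loc, 1.), obs=data)


def guide(data):
    pass  # no latent variables in this toy model


svi = SVI(model, guide, Adam({"lr": 0.1}), loss=JitTrace_ELBO())
data = torch.randn(100) + 3.
for step in range(10):
    svi.step(data)  # the tensor input is passed via *args, as required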
54 changes: 54 additions & 0 deletions pyro/infer/traceenum_elbo.py
@@ -1,7 +1,11 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
@@ -147,3 +151,53 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss


class JitTraceEnum_ELBO(TraceEnum_ELBO):
    """
    Like :class:`TraceEnum_ELBO` but uses :func:`torch.jit.compile` to
    compile :meth:`loss_and_grads`.

    This works only for a limited set of models:

    - Models must have static structure.
    - Models must not depend on any global data (except the param store).
    - All model inputs that are tensors must be passed in via ``*args``.
    - All model inputs that are *not* tensors must be passed in via
      ``**kwargs``, and these will be fixed to their values on the first
      call to :meth:`jit_loss_and_grads`.

    .. warning:: Experimental. Interface subject to change.
    """
    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_differentiable_loss', None) is None:
            # populate param store
            with poutine.block():
                with poutine.trace(param_only=True) as param_capture:
                    for _ in self._get_traces(model, guide, *args, **kwargs):
                        pass
            self._param_names = list(param_capture.trace.nodes.keys())

            # build a closure for differentiable_loss
            weakself = weakref.ref(self)

            @torch.jit.compile(nderivs=1)
            def differentiable_loss(args_list, param_list):
                self = weakself()
                elbo = 0.0
                for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
                    elbo += _compute_dice_elbo(model_trace, guide_trace)
                return elbo * (-1.0 / self.num_particles)

            self._differentiable_loss = differentiable_loss

        # invoke _differentiable_loss
        args_list = list(args)
        param_list = [pyro.param(name).unconstrained() for name in self._param_names]
        differentiable_loss = self._differentiable_loss(args_list, param_list)
        differentiable_loss.backward()  # this line triggers jit compilation
        loss = differentiable_loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
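
The enum variant compiles a single differentiable_loss rather than a (loss, surrogate_loss) pair. It can also be exercised outside SVI; a sketch reusing the hypothetical model, guide, and data from the JitTrace_ELBO example above:

import pyro
from pyro.infer import JitTraceEnum_ELBO

elbo = JitTraceEnum_ELBO(num_particles=1, strict_enumeration_warning=False)
loss = elbo.loss_and_grads(model, guide, data)  # first call builds and compiles the closure
loss = elbo.loss_and_grads(model, guide, data)  # later calls reuse the cached compiled graph
grads = {name: param.grad
         for name, param in pyro.get_param_store().named_parameters()}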
84 changes: 81 additions & 3 deletions pyro/infer/tracegraph_elbo.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref
from operator import itemgetter

import networkx
@@ -99,13 +100,13 @@ def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
    # deal with log p(z|...) terms
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            elbo += torch_item(site["log_prob_sum"])
            elbo += site["log_prob_sum"]
            surrogate_elbo += site["log_prob_sum"]

    # deal with log q(z|...) terms
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            elbo -= torch_item(site["log_prob_sum"])
            elbo -= site["log_prob_sum"]
            entropy_term = site["score_parts"].entropy_term
            if not is_identically_zero(entropy_term):
                surrogate_elbo -= entropy_term.sum()
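
The switch from torch_item(...) to plain tensor arithmetic above keeps elbo inside the autograd graph, which the JIT variants need; converting to a Python float would sever it. A tiny illustration of the difference:

import torch

x = torch.tensor(2.0, requires_grad=True)
elbo_tensor = x.log()         # stays a tensor, so gradients (and the jit trace) can flow through it
elbo_float = x.log().item()   # a plain Python float: no graph, no gradients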
@@ -277,7 +278,84 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -elbo
        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss


class JitTraceGraph_ELBO(TraceGraph_ELBO):
    """
    Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.compile` to
    compile :meth:`loss_and_grads`.

    This works only for a limited set of models:

    - Models must have static structure.
    - Models must not depend on any global data (except the param store).
    - All model inputs that are tensors must be passed in via ``*args``.
    - All model inputs that are *not* tensors must be passed in via
      ``**kwargs``, and these will be fixed to their values on the first
      call to :meth:`loss_and_grads`.

    .. warning:: Experimental. Interface subject to change.
    """

    def loss_and_grads(self, model, guide, *args, **kwargs):
        if getattr(self, '_loss_and_surrogate_loss', None) is None:
            # populate param store
            with poutine.block():
                with poutine.trace(param_only=True) as param_capture:
                    for _ in self._get_traces(model, guide, *args, **kwargs):
                        pass
            self._param_names = list(param_capture.trace.nodes.keys())

            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @torch.jit.compile(nderivs=1)
            def loss_and_surrogate_loss(args_list, param_list):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for weight, model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        for site in model_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)
                        for site in guide_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                                       non_reparam_nodes,
                                                                                       downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss = loss - weight * elbo
                    surrogate_loss = surrogate_loss - weight * surrogate_elbo

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        # invoke _loss_and_surrogate_loss
        args_list = list(args)
        param_list = [pyro.param(name).unconstrained() for name in self._param_names]
        loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list)
        surrogate_loss.backward()  # this line triggers jit compilation
        loss = loss.item()

        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
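
A hedged usage sketch for the non-reparameterized case that TraceGraph-style estimators target, using the NonreparameterizedNormal helper from pyro.distributions.testing.fakes (the same module the tests below import); the model, guide, and data are illustrative only:

import torch

import pyro
import pyro.distributions as dist
from pyro.distributions.testing import fakes
from pyro.infer import SVI, JitTraceGraph_ELBO
from pyro.optim import Adam


def model(data):
    z = pyro.sample("z", dist.Normal(torch.tensor(0.), torch.tensor(1.)))
    pyro.sample("obs", dist.Normal(z, 1.), obs=data)


def guide(data):
    loc = pyro.param("loc", torch.tensor(0.))
    # a non-reparameterized site exercises the score-function surrogate above
    pyro.sample("z", fakes.NonreparameterizedNormal(loc, torch.tensor(1.)))


svi = SVI(model, guide, Adam({"lr": 0.05}), loss=JitTraceGraph_ELBO())
for step in range(10):
    svi.step(torch.randn(20) + 1.)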
29 changes: 22 additions & 7 deletions tests/infer/test_gradient.py
@@ -10,9 +10,10 @@
import pyro
import pyro.distributions as dist
from pyro.distributions.testing import fakes
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO,
TraceGraph_ELBO)
from pyro.optim import Adam
from tests.common import assert_equal
from tests.common import assert_equal, xfail_param

logger = logging.getLogger(__name__)

@@ -112,7 +113,17 @@ def guide():

@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"])
@pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"])
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("Elbo", [
    Trace_ELBO,
    TraceGraph_ELBO,
    TraceEnum_ELBO,
    xfail_param(JitTrace_ELBO,
                reason="jit RuntimeError: Unsupported op descriptor: index-2"),
    xfail_param(JitTraceGraph_ELBO,
                reason="jit RuntimeError: Unsupported op descriptor: index-2"),
    xfail_param(JitTraceEnum_ELBO,
                reason="jit RuntimeError: Unsupported op descriptor: index-2"),
])
def test_subsample_gradient_sequential(Elbo, reparameterized, subsample):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
@@ -134,11 +145,15 @@ def guide():
        pyro.sample("z", Normal(loc[ind], scale))

    optim = Adam({"lr": 0.1})
    elbo = Elbo(num_particles=num_particles, strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, loss=elbo)
    inference.loss_and_grads(model, guide)
    elbo = Elbo(num_particles=10, strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, elbo)
    iters = num_particles // 10
    for _ in range(iters):
        inference.loss_and_grads(model, guide)

    params = dict(pyro.get_param_store().named_parameters())
    actual_grads = {name: param.grad.detach().cpu().numpy() for name, param in params.items()}
    actual_grads = {name: param.grad.detach().cpu().numpy() / iters
                    for name, param in params.items()}

    expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
    for name in sorted(params):