Support torch.jit in ELBO implementations #1109

Merged 9 commits on May 5, 2018
4 changes: 4 additions & 0 deletions docs/source/inference_algos.rst
@@ -13,21 +13,25 @@ ELBO
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.trace_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.tracegraph_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.traceenum_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Importance
----------
17 changes: 10 additions & 7 deletions pyro/infer/__init__.py
@@ -5,22 +5,25 @@
from pyro.infer.enum import config_enumerate
from pyro.infer.importance import Importance
from pyro.infer.svi import SVI
from pyro.infer.trace_elbo import Trace_ELBO
from pyro.infer.traceenum_elbo import TraceEnum_ELBO
from pyro.infer.tracegraph_elbo import TraceGraph_ELBO
from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
from pyro.infer.util import enable_validation, is_validation_enabled

__all__ = [
"config_enumerate",
"enable_validation",
"is_validation_enabled",
"ELBO",
"Importance",
"EmpiricalMarginal",
"TracePredictive",
"Importance",
"JitTraceEnum_ELBO",
"JitTraceGraph_ELBO",
"JitTrace_ELBO",
"SVI",
"TracePosterior",
"Trace_ELBO",
"TraceEnum_ELBO",
"TraceGraph_ELBO",
"TracePosterior",
"TracePredictive",
"Trace_ELBO",
]
82 changes: 82 additions & 0 deletions pyro/infer/trace_elbo.py
@@ -1,7 +1,11 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
@@ -142,3 +146,81 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return loss


class JitTrace_ELBO(Trace_ELBO):
"""
Like :class:`Trace_ELBO` but uses :func:`torch.jit.compile` to compile
:meth:`loss_and_grads`.

This works only for a limited set of models:

- Models must have static structure.
- Models must not depend on any global data (except the param store).
- All model inputs that are tensors must be passed in via ``*args``.
- All model inputs that are *not* tensors must be passed in via
``*kwargs``, and these will be fixed to their values on the first
call to :meth:`loss_and_grads`.

.. warning:: Experimental. Interface subject to change.
"""
def loss_and_grads(self, model, guide, *args, **kwargs):
if getattr(self, '_loss_and_surrogate_loss', None) is None:
# populate param store
with poutine.block():
with poutine.trace(param_only=True) as param_capture:
for _ in self._get_traces(model, guide, *args, **kwargs):
pass
self._param_names = list(param_capture.trace.nodes.keys())

# build a closure for loss_and_surrogate_loss
weakself = weakref.ref(self)

@torch.jit.compile(nderivs=1)
def loss_and_surrogate_loss(args_list, param_list):
self = weakself()
loss = 0.0
surrogate_loss = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
elbo_particle = 0
surrogate_elbo_particle = 0
log_r = None

# compute elbo and surrogate elbo
for name, site in model_trace.nodes.items():
if site["type"] == "sample":
elbo_particle = elbo_particle + site["log_prob_sum"]
surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

for name, site in guide_trace.nodes.items():
if site["type"] == "sample":
log_prob, score_function_term, entropy_term = site["score_parts"]

elbo_particle = elbo_particle - site["log_prob_sum"]

if not is_identically_zero(entropy_term):
surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

if not is_identically_zero(score_function_term):
if log_r is None:
log_r = _compute_log_r(model_trace, guide_trace)
site = log_r.sum_to(site["cond_indep_stack"])
surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

loss = loss - elbo_particle / self.num_particles
surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles

return loss, surrogate_loss

self._loss_and_surrogate_loss = loss_and_surrogate_loss

# invoke _loss_and_surrogate_loss
args_list = list(args)
param_list = [pyro.param(name).unconstrained() for name in self._param_names]
loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list)
surrogate_loss.backward() # this line triggers jit compilation
loss = loss.item()

if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return loss
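For orientation, a minimal usage sketch of the new class, following the constraints listed in the JitTrace_ELBO docstring above (static model structure, all tensor inputs passed positionally). The model, guide, and parameter names here are illustrative, and pyro.iarange is the 0.2-era API implied by max_iarange_nesting in this diff:

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam

def model(data):
    # all tensor inputs arrive via *args
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.iarange("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    guide_loc = pyro.param("guide_loc", torch.tensor(0.))
    guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                             constraint=constraints.positive)
    pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

data = torch.randn(100) + 3.
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTrace_ELBO())
for step in range(100):
    svi.step(data)  # the first call triggers jit compilation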
54 changes: 54 additions & 0 deletions pyro/infer/traceenum_elbo.py
@@ -1,7 +1,11 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
@@ -147,3 +151,53 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return loss


class JitTraceEnum_ELBO(TraceEnum_ELBO):
"""
Like :class:`TraceEnum_ELBO` but uses :func:`torch.jit.compile` to
compile :meth:`loss_and_grads`.

This works only for a limited set of models:

- Models must have static structure.
- Models must not depend on any global data (except the param store).
- All model inputs that are tensors must be passed in via ``*args``.
- All model inputs that are *not* tensors must be passed in via
``*kwargs``, and these will be fixed to their values on the first
call to :meth:`loss_and_grads`.

.. warning:: Experimental. Interface subject to change.
"""
def loss_and_grads(self, model, guide, *args, **kwargs):
if getattr(self, '_differentiable_loss', None) is None:
# populate param store
with poutine.block():
with poutine.trace(param_only=True) as param_capture:
for _ in self._get_traces(model, guide, *args, **kwargs):
pass
self._param_names = list(param_capture.trace.nodes.keys())

# build a closure for differentiable_loss
weakself = weakref.ref(self)

@torch.jit.compile(nderivs=1)
def differentiable_loss(args_list, param_list):
self = weakself()
elbo = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
elbo += _compute_dice_elbo(model_trace, guide_trace)
return elbo * (-1.0 / self.num_particles)

self._differentiable_loss = differentiable_loss

# invoke _differentiable_loss
args_list = list(args)
param_list = [pyro.param(name).unconstrained() for name in self._param_names]
differentiable_loss = self._differentiable_loss(args_list, param_list)
differentiable_loss.backward() # this line triggers jit compilation
loss = differentiable_loss.item()

if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return loss
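All three Jit*_ELBO classes cache the compiled closure on self while the closure itself needs to call back into self; routing that access through weakref.ref(self) avoids a strong reference cycle between the ELBO object and its cached closure. A standalone sketch of the pattern, with an illustrative Cache class that is not Pyro code:

import weakref

class Cache(object):
    def __init__(self, scale):
        self.scale = scale
        self._fn = None

    def compute(self, x):
        if self._fn is None:
            weakself = weakref.ref(self)  # the closure keeps only a weak reference

            def fn(x):
                self = weakself()  # resolve the weak reference at call time
                return self.scale * x

            self._fn = fn  # self -> fn, but no strong fn -> self edge
        return self._fn(x)

c = Cache(2.0)
assert c.compute(3.0) == 6.0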
84 changes: 81 additions & 3 deletions pyro/infer/tracegraph_elbo.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import warnings
import weakref
from operator import itemgetter

import networkx
@@ -99,13 +100,13 @@ def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
# deal with log p(z|...) terms
for name, site in model_trace.nodes.items():
if site["type"] == "sample":
elbo += torch_item(site["log_prob_sum"])
elbo += site["log_prob_sum"]
surrogate_elbo += site["log_prob_sum"]

# deal with log q(z|...) terms
for name, site in guide_trace.nodes.items():
if site["type"] == "sample":
elbo -= torch_item(site["log_prob_sum"])
elbo -= site["log_prob_sum"]
entropy_term = site["score_parts"].entropy_term
if not is_identically_zero(entropy_term):
surrogate_elbo -= entropy_term.sum()
@@ -277,7 +278,84 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
surrogate_loss = -surrogate_elbo
torch_backward(weight * (surrogate_loss + baseline_loss))

loss = -elbo
loss = -torch_item(elbo)
if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return weight * loss


class JitTraceGraph_ELBO(TraceGraph_ELBO):
"""
Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.compile` to
compile :meth:`loss_and_grads`.

This works only for a limited set of models:

- Models must have static structure.
- Models must not depend on any global data (except the param store).
- All model inputs that are tensors must be passed in via ``*args``.
- All model inputs that are *not* tensors must be passed in via
``*kwargs``, and these will be fixed to their values on the first
call to :meth:`loss_and_grads`.

.. warning:: Experimental. Interface subject to change.
"""

def loss_and_grads(self, model, guide, *args, **kwargs):
if getattr(self, '_loss_and_surrogate_loss', None) is None:
# populate param store
with poutine.block():
with poutine.trace(param_only=True) as param_capture:
for _ in self._get_traces(model, guide, *args, **kwargs):
pass
self._param_names = list(param_capture.trace.nodes.keys())

# build a closure for loss_and_surrogate_loss
weakself = weakref.ref(self)

@torch.jit.compile(nderivs=1)
def loss_and_surrogate_loss(args_list, param_list):
self = weakself()
loss = 0.0
surrogate_loss = 0.0
for weight, model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

# compute elbo for reparameterized nodes
non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

# the following computations are only necessary if we have non-reparameterizable nodes
baseline_loss = 0.0
if non_reparam_nodes:
downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
non_reparam_nodes,
downstream_costs)
surrogate_elbo += surrogate_elbo_term

loss = loss - weight * elbo
surrogate_loss = surrogate_loss - weight * surrogate_elbo

return loss, surrogate_loss

self._loss_and_surrogate_loss = loss_and_surrogate_loss

# invoke _loss_and_surrogate_loss
args_list = list(args)
param_list = [pyro.param(name).unconstrained() for name in self._param_names]
loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list)
surrogate_loss.backward() # this line triggers jit compilation
loss = loss.item()

if torch_isnan(loss):
warnings.warn('Encountered NAN loss')
return loss
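Each of the new loss_and_grads methods begins with the same warm-up: run the traces once under poutine.block() and poutine.trace(param_only=True) so the param store is populated and the touched parameter names are recorded, then pass the corresponding unconstrained tensors into the compiled closure. A sketch of that step pulled out as a helper; the function name is illustrative, while the calls mirror the code above:

import pyro
import pyro.poutine as poutine

def capture_param_names(elbo, model, guide, *args, **kwargs):
    # run one blocked pass to populate the param store and record param names
    with poutine.block():
        with poutine.trace(param_only=True) as param_capture:
            for _ in elbo._get_traces(model, guide, *args, **kwargs):
                pass
    param_names = list(param_capture.trace.nodes.keys())
    # the unconstrained tensors are what get passed into the compiled closure
    param_list = [pyro.param(name).unconstrained() for name in param_names]
    return param_names, param_list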
29 changes: 22 additions & 7 deletions tests/infer/test_gradient.py
@@ -10,9 +10,10 @@
import pyro
import pyro.distributions as dist
from pyro.distributions.testing import fakes
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO,
TraceGraph_ELBO)
from pyro.optim import Adam
from tests.common import assert_equal
from tests.common import assert_equal, xfail_param

logger = logging.getLogger(__name__)

@@ -112,7 +113,17 @@ def guide():

@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"])
@pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"])
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@pytest.mark.parametrize("Elbo", [
Trace_ELBO,
TraceGraph_ELBO,
TraceEnum_ELBO,
xfail_param(JitTrace_ELBO,
reason="jit RuntimeError: Unsupported op descriptor: index-2"),
xfail_param(JitTraceGraph_ELBO,
reason="jit RuntimeError: Unsupported op descriptor: index-2"),
xfail_param(JitTraceEnum_ELBO,
reason="jit RuntimeError: Unsupported op descriptor: index-2"),
])
def test_subsample_gradient_sequential(Elbo, reparameterized, subsample):
pyro.clear_param_store()
data = torch.tensor([-0.5, 2.0])
@@ -134,11 +145,15 @@ def guide():
pyro.sample("z", Normal(loc[ind], scale))

optim = Adam({"lr": 0.1})
elbo = Elbo(num_particles=num_particles, strict_enumeration_warning=False)
inference = SVI(model, guide, optim, loss=elbo)
inference.loss_and_grads(model, guide)
elbo = Elbo(num_particles=10, strict_enumeration_warning=False)
inference = SVI(model, guide, optim, elbo)
iters = num_particles // 10
for _ in range(iters):
inference.loss_and_grads(model, guide)

params = dict(pyro.get_param_store().named_parameters())
actual_grads = {name: param.grad.detach().cpu().numpy() for name, param in params.items()}
actual_grads = {name: param.grad.detach().cpu().numpy() / iters
for name, param in params.items()}

expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
for name in sorted(params):