Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorials using normalizing flows #3302

Merged
merged 15 commits into from
Jan 14, 2024
5 changes: 5 additions & 0 deletions docs/source/contrib.zuko.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Zuko in Pyro
============

.. automodule:: pyro.contrib.zuko
:members:
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Pyro Documentation
contrib.randomvariable
contrib.timeseries
contrib.tracking
contrib.zuko


Indices and tables
Expand Down
77 changes: 77 additions & 0 deletions pyro/contrib/zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This file contains helpers to use `Zuko <https://zuko.readthedocs.io/>`_-based
normalizing flows within Pyro piplines.

Accompanying tutorials can be found at `tutorial/svi_flow_guide.ipynb` and
`tutorial/vae_flow_prior.ipynb`.
"""

import torch
from torch import Size, Tensor

import pyro


class Zuko2Pyro(pyro.distributions.TorchDistribution):
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
r"""Wraps a Zuko distribution as a Pyro distribution.

:param dist: A distribution instance.
:type dist: torch.distributions.Distribution

.. code-block:: python

flow = zuko.flows.MAF(features=5)

# flow() is a torch.distributions.Distribution

dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)

# Zuko2Pyro(flow()) is a pyro.distributions.Distribution

dist = Zuko2Pyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)

with pyro.plate("data", 42):
z = pyro.sample("z", dist)
"""

def __init__(self, dist: torch.distributions.Distribution):
self.dist = dist
self.cache = {}

@property
def has_rsample(self) -> bool:
return self.dist.has_rsample

@property
def event_shape(self) -> Size:
return self.dist.event_shape

@property
def batch_shape(self) -> Size:
return self.dist.batch_shape

def __call__(self, shape: Size = ()) -> Tensor:
if hasattr(self.dist, "rsample_and_log_prob"): # fast sampling + scoring
x, self.cache[x] = self.dist.rsample_and_log_prob(shape)
elif self.has_rsample:
x = self.dist.rsample(shape)
else:
x = self.dist.sample(shape)

return x

def log_prob(self, x: Tensor) -> Tensor:
if x in self.cache:
return self.cache[x]
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
else:
return self.dist.log_prob(x)

def expand(self, *args, **kwargs):
return Zuko2Pyro(self.dist.expand(*args, **kwargs))
56 changes: 56 additions & 0 deletions tests/contrib/test_zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


import pytest
import torch

import pyro
from pyro.contrib.zuko import Zuko2Pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


@pytest.mark.parametrize("multivariate", [True, False])
def test_Zuko2Pyro(multivariate: bool):
francois-rozet marked this conversation as resolved.
Show resolved Hide resolved
# Distribution
if multivariate:
normal = torch.distributions.MultivariateNormal
mu = torch.zeros(3)
sigma = torch.eye(3)
else:
normal = torch.distributions.Normal
mu = torch.zeros(())
sigma = torch.ones(())

dist = normal(mu, sigma)

# Sample
x1 = pyro.sample("x1", Zuko2Pyro(dist))

assert x1.shape == dist.event_shape

# Sample within plate
with pyro.plate("data", 4):
x2 = pyro.sample("x2", Zuko2Pyro(dist))

assert x2.shape == (4, *dist.event_shape)

# SVI
def model():
pyro.sample("a", Zuko2Pyro(dist))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(dist))

def guide():
mu_ = pyro.param("mu", mu)
sigma_ = pyro.param("sigma", sigma)

pyro.sample("a", Zuko2Pyro(normal(mu_, sigma_)))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(normal(mu_, sigma_)))

svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
svi.step()
4 changes: 3 additions & 1 deletion tutorial/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ List of Tutorials
jit
svi_horovod
svi_lightning
svi_flow_guide

.. toctree::
:maxdepth: 1
Expand All @@ -106,7 +107,8 @@ List of Tutorials
vae
ss-vae
cvae
normalizing_flows_i
normalizing_flows_intro
vae_flow_prior
dmm
air
cevae
Expand Down

Large diffs are not rendered by default.

238 changes: 238 additions & 0 deletions tutorial/source/svi_flow_guide.ipynb

Large diffs are not rendered by default.

Loading
Loading