PyTorch Lightning example (#3189) #2

Merged 1 commit on Mar 16, 2023
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
@@ -12,7 +12,7 @@
# https://horovod.readthedocs.io/en/stable
#
# This assumes you have installed horovod, e.g. via
# pip install pyro[horovod]
# pip install pyro-ppl[horovod]
# For detailed instructions see
# https://horovod.readthedocs.io/en/stable/install.html
# On my mac laptop I was able to install horovod with
116 changes: 116 additions & 0 deletions examples/svi_lightning.py
@@ -0,0 +1,116 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Distributed training via PyTorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
# library. PyTorch Lightning enables data-parallel training by aggregating stochastic
# gradients at each step of training. We focus on integration between PyTorch Lightning and Pyro.
# For further details on distributed computing with PyTorch Lightning, see
# https://lightning.ai/docs/pytorch/latest
#
# This assumes you have installed PyTorch Lightning, e.g. via
# pip install pyro-ppl[lightning]

import argparse

import pytorch_lightning as pl
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule


# We define a model as usual, with no reference to PyTorch Lightning.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
def __init__(self, size):
super().__init__()
self.size = size

def forward(self, covariates, data=None):
coeff = pyro.sample("coeff", dist.Normal(0, 1))
bias = pyro.sample("bias", dist.Normal(0, 1))
scale = pyro.sample("scale", dist.LogNormal(0, 1))

# Since we'll use a distributed dataloader during training, we need to
# manually pass minibatches of (covariates,data) that are smaller than
# the full self.size. In particular we cannot rely on pyro.plate to
# automatically subsample, since that would lead to all workers drawing
# identical subsamples.
with pyro.plate("data", self.size, len(covariates)):
loc = bias + coeff * covariates
return pyro.sample("obs", dist.Normal(loc, scale), obs=data)


# We define an ELBO loss, a PyTorch optimizer, and a training step in our PyroLightningModule.
# Note that we are using a PyTorch optimizer instead of a Pyro optimizer and
# we are using ``training_step`` instead of Pyro's SVI machinery.
class PyroLightningModule(pl.LightningModule):
def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
super().__init__()
self.loss_fn = loss_fn
self.model = loss_fn.model
self.guide = loss_fn.guide
self.lr = lr
self.predictive = pyro.infer.Predictive(
self.model, guide=self.guide, num_samples=1
)

def forward(self, *args):
return self.predictive(*args)

def training_step(self, batch, batch_idx):
"""Training step for Pyro training."""
loss = self.loss_fn(*batch)
# Logging to TensorBoard by default
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
"""Configure an optimizer."""
return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)


def main(args):
# Create a model, synthetic data, a guide, and a lightning module.
pyro.set_rng_seed(args.seed)
pyro.settings.set(module_local_params=True)
model = Model(args.size)
covariates = torch.randn(args.size)
data = model(covariates)
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)
training_plan = PyroLightningModule(loss_fn, args.learning_rate)

# Create a dataloader.
dataset = torch.utils.data.TensorDataset(covariates, data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

    # All relevant parameters need to be initialized before ``configure_optimizers`` is called.
    # Since we use an AutoNormal guide, our parameters have not been initialized yet,
    # so we initialize the model and guide by running one mini-batch through the loss.
mini_batch = dataset[: args.batch_size]
loss_fn(*mini_batch)

# Run stochastic variational inference using PyTorch Lightning Trainer.
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(training_plan, train_dataloaders=dataloader)


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.4")
parser = argparse.ArgumentParser(
description="Distributed training via PyTorch Lightning"
)
parser.add_argument("--size", default=1000000, type=int)
parser.add_argument("--batch_size", default=100, type=int)
parser.add_argument("--learning_rate", default=0.01, type=float)
parser.add_argument("--seed", default=20200723, type=int)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
main(args)
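For contrast with the ``training_step`` approach above, which pairs the ELBO module with a plain PyTorch optimizer, the same model and guide could be trained single-process with Pyro's own SVI machinery. A minimal sketch, not part of this PR; it assumes the ``Model`` class from ``examples/svi_lightning.py`` above is in scope, and the size and step count are arbitrary::

    import torch
    import pyro
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoNormal
    from pyro.optim import Adam

    # Single-process comparison point: Pyro's SVI loop with a Pyro optimizer,
    # which the Lightning example above replaces with torch.optim.Adam and a
    # LightningModule training_step.
    pyro.set_rng_seed(0)
    model = Model(1000)               # Model as defined in the example above (assumed importable)
    covariates = torch.randn(1000)
    data = model(covariates)          # draw synthetic data from the prior
    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for step in range(100):
        loss = svi.step(covariates, data)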
1 change: 1 addition & 0 deletions setup.py
@@ -137,6 +137,7 @@
"yapf",
],
"horovod": ["horovod[pytorch]>=0.19"],
"lightning": ["pytorch_lightning"],
"funsor": [
# This must be a released version when Pyro is released.
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
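The new ``lightning`` extra mirrors the existing ``horovod`` extra above; once a release includes it, the example's dependency can be installed as noted in the example's header, for example::

    pip install pyro-ppl[lightning]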
8 changes: 8 additions & 0 deletions tests/common.py
@@ -68,6 +68,14 @@ def wrapper(*args, **kwargs):
horovod is None, reason="horovod is not available"
)

try:
import pytorch_lightning
except ImportError:
pytorch_lightning = None
requires_lightning = pytest.mark.skipif(
pytorch_lightning is None, reason="pytorch lightning is not available"
)

try:
import funsor
except ImportError:
9 changes: 9 additions & 0 deletions tests/test_examples.py
@@ -14,6 +14,7 @@
requires_cuda,
requires_funsor,
requires_horovod,
requires_lightning,
xfail_param,
)

@@ -110,6 +111,10 @@
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto",
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy",
"svi_horovod.py --num-epochs=2 --size=400 --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1",
marks=[requires_lightning],
),
"toy_mixture_model_discrete_enumeration.py --num-steps=1",
"sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11",
"vae/ss_vae_M2.py --num-epochs=1",
@@ -177,6 +182,10 @@
"sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda",
"sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda",
"svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1",
marks=[requires_lightning],
),
"vae/vae.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda",
1 change: 1 addition & 0 deletions tutorial/source/index.rst
@@ -96,6 +96,7 @@ List of Tutorials
prior_predictive
jit
svi_horovod
svi_lightning

.. toctree::
:maxdepth: 1
17 changes: 17 additions & 0 deletions tutorial/source/svi_lightning.rst
@@ -0,0 +1,17 @@
Example: distributed training via PyTorch Lightning
===================================================

This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example::

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts

`View svi_lightning.py on github`__

.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py

__ github_

.. literalinclude:: ../../examples/svi_lightning.py
:language: python
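After ``trainer.fit`` completes, the module's ``forward`` (wired to ``pyro.infer.Predictive`` in the example above) can be used to draw posterior predictive samples. A minimal sketch, assuming it runs at the end of ``main()`` in ``examples/svi_lightning.py`` above (so ``training_plan`` and ``covariates`` are defined) and a CPU run as in the test invocation::

    # Draw posterior predictive samples for the first ten covariates using the
    # trained guide (Predictive was configured with num_samples=1 above).
    with torch.no_grad():
        samples = training_plan(covariates[:10])
        print(samples["obs"].shape)  # (num_samples, minibatch), i.e. (1, 10)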