Add modified Bayesian regression tutorial with more direct PyTorch usage #2996

Open · wants to merge 1 commit into dev
Conversation

@eb8680 (Member) commented Dec 14, 2021

This PR attempts a couple of small changes to the Pyro and PyroModule APIs to make PyroModule more compatible with vanilla PyTorch programming idioms. The API changes are simple, although the implementations inside PyroModule are a bit hacky and may not yet be correct.

Changes:

  • Adds a new global configuration toggle pyro.enable_module_local_param() that causes PyroModule parameters to be stored locally on the module, rather than in the global parameter store. Currently implemented by associating a new ParamStoreDict object with each PyroModule instance, which may not be ideal.
  • Adds a backwards-compatible __call__ method to pyro.infer.ELBO that returns a torch.nn.Module bound to a specific model and guide, allowing direct use of the PyTorch JIT API (e.g. torch.jit.trace).
  • Forks the Bayesian regression tutorial into a PyTorch API usage tutorial to illustrate the PyTorch-native programming style facilitated by these changes and PyroModule; a minimal sketch of the first two changes follows this list.
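
As a rough illustration of the first two changes, here is a minimal hedged sketch, assuming the API proposed in this PR (BayesianRegression, x_data, and y_data are from the tutorial notebook):

import torch
import pyro
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

# New toggle proposed in this PR: store parameters on each module instance
pyro.enable_module_local_param(True)

model = BayesianRegression(3, 1)   # PyroModule subclass from the tutorial
guide = AutoDiagonalNormal(model)

# New ELBO.__call__: binding to (model, guide) yields a torch.nn.Module
elbo_module = Trace_ELBO()(model, guide)
assert isinstance(elbo_module, torch.nn.Module)

# One forward pass populates the guide's parameters; with module-local
# parameters enabled, the global parameter store should remain empty
elbo_module(x_data, y_data)
assert len(pyro.get_param_store()) == 0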

@eb8680 (Member, Author) commented Dec 14, 2021

For context, here is a condensed training loop from the tutorial notebook that I was trying to enable:

import torch
import pyro
from pyro.infer import Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule

# x_data, y_data: training tensors defined earlier in the tutorial notebook

# new: keep PyroParams out of the global parameter store
pyro.enable_module_local_param(True)

class BayesianRegression(PyroModule):
    ...

# Create fresh copies of model, guide, elbo
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
elbo = Trace_ELBO(num_particles=10)

# new: bind elbo to (model, guide) pair
elbo = elbo(model, guide)

# Populate guide parameters
elbo(x_data, y_data)
# new: use torch.optim directly
optim = torch.optim.Adam(guide.parameters(), lr=0.03)

# Temporarily disable runtime validation and compile ELBO
with pyro.validation_enabled(False):
    # new: use torch.jit.trace directly
    elbo = torch.jit.trace(elbo, (x_data, y_data), check_trace=False, strict=False)

# Optimize with a plain PyTorch training loop
for j in range(1500):
    loss = elbo(x_data, y_data)
    optim.zero_grad()
    loss.backward()
    optim.step()

# Prediction: draw posterior samples with Predictive
predict_fn = Predictive(model, guide=guide, num_samples=800)
# new: use torch.jit.trace directly
predict_fn = torch.jit.trace(predict_fn, (x_data,), check_trace=False, strict=False)
samples = predict_fn(x_data)
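
For completeness, a hedged sketch of how the returned samples might then be summarized; Predictive returns a dict of tensors keyed by sample-site name, and the "obs" site name below is an assumption based on the standard Bayesian regression tutorial:

# Each entry has a leading dimension of size num_samples (800 here);
# "obs" is assumed to be the likelihood site, as in the standard tutorial
y_mean = samples["obs"].mean(dim=0)  # posterior predictive mean
y_std = samples["obs"].std(dim=0)    # predictive uncertainty per data point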
