MuE distributions for Pyro. #2728

Merged — 93 commits, Mar 23, 2021

Commits
acbd882
Initialize main project files.
EWeinstein Dec 28, 2020
cbf6102
Module for converting MuE parameters (in particular, profile HMM) to …
EWeinstein Dec 28, 2020
1fe82b1
Test state arranger.
EWeinstein Dec 28, 2020
6e2868b
Debug state arranger tests.
EWeinstein Dec 29, 2020
7ddecb6
Variable length discrete hmm class and log probability function.
EWeinstein Dec 30, 2020
c2458f4
Test for log probability of variable length hmm.
EWeinstein Dec 30, 2020
1ea5af6
Simple profile hmm example model.
EWeinstein Dec 30, 2020
2f72384
Switch to standard hmm log probability convention.
EWeinstein Dec 30, 2020
1725392
FactorMuE forward pass.
EWeinstein Dec 31, 2020
bfba6ae
FactorMuE plots, debug.
EWeinstein Dec 31, 2020
c017cf5
Multistage training.
EWeinstein Dec 31, 2020
0ae6070
Cleanup.
EWeinstein Dec 31, 2020
112e86a
Add parser, adjust plot saving to avoid creating new folders.
EWeinstein Dec 31, 2020
f80100f
Cleanup imports.
EWeinstein Dec 31, 2020
a11aef0
More extensive testing for VariableLengthHMM.
EWeinstein Jan 14, 2021
81e0d49
Docs for mue.
EWeinstein Jan 15, 2021
e18b033
Shape tests and trivial case tests for profile statearranger.
EWeinstein Jan 15, 2021
a6ad405
Cleaning up naming conventions and doc string conventions.
EWeinstein Jan 16, 2021
d74a6f9
Docstrings with parameter details.
EWeinstein Jan 25, 2021
827721a
Improve indexing conventions, add unit tests.
EWeinstein Feb 8, 2021
e973c23
Rename files and build complete set of options for FactorMuE
EWeinstein Feb 10, 2021
c98ece0
Cleanup.
EWeinstein Feb 10, 2021
fa90df4
Basic dataloaders, and start rearranging training to be part of model.
EWeinstein Feb 11, 2021
0d7ea6c
subsampling inference provided with model
EWeinstein Feb 11, 2021
8cb2642
FactorMuE test.
EWeinstein Feb 12, 2021
7977022
Inference, length model, tests for profile hmm.
EWeinstein Feb 19, 2021
5d085cf
FactorMuE example full input options.
EWeinstein Feb 19, 2021
921eebe
Debug saving.
EWeinstein Feb 20, 2021
f07ce6d
Clean up
EWeinstein Feb 20, 2021
f982316
Add FactorMuE to test_examples.py
EWeinstein Feb 20, 2021
42e97bc
Files autochanged by make format.
EWeinstein Feb 20, 2021
cca23e1
Revert "Add FactorMuE to test_examples.py"
EWeinstein Feb 20, 2021
9132ae5
Revert "Files autochanged by make format."
EWeinstein Feb 20, 2021
f696c16
Add back in test examples lines.
EWeinstein Feb 20, 2021
cb31d2d
put in constraints
EWeinstein Feb 20, 2021
5c9500a
Merge remote-tracking branch 'upstream/dev' into mue
EWeinstein Feb 20, 2021
dd1a383
Git ignore .fasta files.
EWeinstein Feb 20, 2021
d2c513c
Make format's automated changes.
EWeinstein Feb 20, 2021
5c77f16
Revert "Make format's automated changes."
EWeinstein Feb 20, 2021
161f1e1
Adjust license headers.
EWeinstein Feb 20, 2021
1a86f9a
Profile and Factor example tests edited and added to main list.
EWeinstein Feb 20, 2021
ae861cc
adjust prior names
EWeinstein Feb 21, 2021
08f16cb
Rearrange profile HMM for jit compilation.
EWeinstein Feb 21, 2021
6c6a846
Debug jit compile ELBO in profile HMM
EWeinstein Feb 21, 2021
ede13d4
Reconfigure FactorMuE for jit compilation
EWeinstein Feb 21, 2021
1dcd4a8
Debug jit compilation for FactorMuE.
EWeinstein Feb 21, 2021
07acd22
Beta annealing in FactorMuE.
EWeinstein Feb 21, 2021
1506697
Evaluate train test elbo and perplexity for ProfileHMM
EWeinstein Feb 21, 2021
a4a620e
Heldout likelihood evaluation for factormue.
EWeinstein Feb 21, 2021
26061b7
switch to local elbo evaluation and add more options to factor example.
EWeinstein Feb 25, 2021
33730bd
Cuda option.
EWeinstein Feb 25, 2021
2304b0e
pin memory option
EWeinstein Feb 25, 2021
04e9a22
Fix data tensor initialization.
EWeinstein Feb 25, 2021
e07cca7
Move data to cuda.
EWeinstein Feb 26, 2021
fd52f5b
Transfer results back to cpu.
EWeinstein Feb 26, 2021
15e4b17
Move results back to cpu.
EWeinstein Feb 26, 2021
028e305
Move more results to cpu.
EWeinstein Feb 26, 2021
2311edc
Speed up initialization.
EWeinstein Feb 26, 2021
8b02abe
Move to device in generator.
EWeinstein Feb 26, 2021
a7b452b
Adjust cuda device transfer.
EWeinstein Feb 26, 2021
490be6c
Move data to device for now for ease.
EWeinstein Feb 26, 2021
c6e7f2d
Try another way of fixing cuda error.
EWeinstein Feb 26, 2021
f550447
Try disabling device transfer in dataloader entirely.
EWeinstein Feb 26, 2021
ef943cc
Disable device handling entirely.
EWeinstein Feb 26, 2021
ec9b216
Add back in device handling in random split
EWeinstein Feb 26, 2021
c247c99
Try moving data lengths to cuda?
EWeinstein Feb 26, 2021
50e2149
Try another combination of device calls.
EWeinstein Feb 26, 2021
f06a1b4
Try adjusting generator again.
EWeinstein Feb 26, 2021
a455ba8
Try removing generator statement entirely.
EWeinstein Feb 26, 2021
511b51e
Try removing data util random split call.
EWeinstein Feb 26, 2021
1301e97
Clean up comments.
EWeinstein Feb 26, 2021
690dae8
Fix embedding ordering.
EWeinstein Feb 27, 2021
999bbe2
Update profile HMM example.
EWeinstein Feb 27, 2021
1fcbd67
Update tests, improve alphabet handling.
EWeinstein Feb 27, 2021
aa2850d
Adjust cuda defaults.
EWeinstein Feb 27, 2021
23ddc45
Move back to cpu for plotting.
EWeinstein Feb 27, 2021
f70115e
Merge remote-tracking branch 'upstream/dev' into mue
EWeinstein Mar 15, 2021
ca345f0
Documentation for profile hmm model.
EWeinstein Mar 15, 2021
cdea31d
Docs for FactorMuE model.
EWeinstein Mar 16, 2021
d8c9546
Cleaned up docs.
EWeinstein Mar 16, 2021
8542820
Tutorials.
EWeinstein Mar 16, 2021
5199844
Add example run scripts.
EWeinstein Mar 16, 2021
0085283
Make format changes.
EWeinstein Mar 16, 2021
b708557
Fix example boolean inputs
EWeinstein Mar 16, 2021
a4fc2e7
Wording edit.
EWeinstein Mar 16, 2021
3089861
Add jit and cuda calls to test_examples.
EWeinstein Mar 17, 2021
aa388a3
Add stop codon handling
EWeinstein Mar 17, 2021
df10c7b
Remove old length modeling mechanism in favor of using stop symbols
EWeinstein Mar 17, 2021
767d9e0
Option to keep data on cpu
EWeinstein Mar 17, 2021
7617acb
check device of seq data
EWeinstein Mar 17, 2021
839adf5
CPU storage option for profile HMM
EWeinstein Mar 17, 2021
4dcf385
Update example tests
EWeinstein Mar 17, 2021
2e192a8
Addressed fritzo comments
EWeinstein Mar 22, 2021
examples/contrib/mue/FactorMuE.py — 310 additions, 0 deletions
@@ -0,0 +1,310 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
A PCA model with a MuE emission (FactorMuE). Uses the MuE package.
"""

import argparse
import datetime
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.nn.functional import softplus
from torch.optim import Adam

import pyro
import pyro.distributions as dist

from pyro.contrib.mue.statearrangers import profile
from pyro.contrib.mue.variablelengthhmm import VariableLengthDiscreteHMM

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import MultiStepLR
import pyro.poutine as poutine


class Encoder(nn.Module):
    """Amortized inference network: maps a one-hot observed sequence to the
    location and scale of the latent variable z."""

    def __init__(self, obs_seq_length, alphabet_length, z_dim):
        super().__init__()

        self.input_size = obs_seq_length * alphabet_length
        self.f1_mn = nn.Linear(self.input_size, z_dim)
        self.f1_sd = nn.Linear(self.input_size, z_dim)

    def forward(self, data):

        data = data.reshape(-1, self.input_size)
        z_loc = self.f1_mn(data)
        z_scale = softplus(self.f1_sd(data))

        return z_loc, z_scale
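
# Shape sketch (sizes taken from the example in main() below): with
# obs_seq_length=6, alphabet_length=2, z_dim=2, the encoder flattens data of
# shape (N, 6, 2) to (N, 12) and returns z_loc and z_scale, each of shape (N, 2).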


class Decoder(nn.Module):
    """Maps the latent variable z to ancestor and insertion sequence logits."""

    def __init__(self, latent_seq_length, alphabet_length, z_dim):
        super().__init__()

        self.latent_seq_length = latent_seq_length
        self.alphabet_length = alphabet_length
        # Two sequences (ancestor and insert), each over latent_seq_length+1
        # positions and alphabet_length characters.
        self.output_size = 2 * (latent_seq_length+1) * alphabet_length
        self.f = nn.Linear(z_dim, self.output_size)

    def forward(self, z):

        seq = self.f(z)
        seq = seq.reshape([-1, 2, self.latent_seq_length+1,
                           self.alphabet_length])
        return seq
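
# Shape sketch (same sizes as in main() below): Decoder(6, 2, 2) maps z of
# shape (N, 2) to logits of shape (N, 2, 7, 2), where index 0 of dim 1 holds
# the ancestor sequence and index 1 the insertion sequence, each over
# latent_seq_length+1 positions.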


class FactorMuE(nn.Module):

    def __init__(self, obs_seq_length, alphabet_length, z_dim,
                 scale_factor=1.,
                 latent_seq_length=None, prior_scale=1.,
                 indel_prior_strength=10., inverse_temp_prior=100.):
        super().__init__()

        # Constants.
        assert isinstance(obs_seq_length, int) and obs_seq_length > 0
        self.obs_seq_length = obs_seq_length
        if latent_seq_length is None:
            latent_seq_length = obs_seq_length
        else:
            assert isinstance(latent_seq_length, int) and latent_seq_length > 0
        self.latent_seq_length = latent_seq_length
        assert isinstance(alphabet_length, int) and alphabet_length > 0
        self.alphabet_length = alphabet_length
        assert isinstance(z_dim, int) and z_dim > 0
        self.z_dim = z_dim

        # Parameter shapes.
        self.seq_shape = (latent_seq_length+1, alphabet_length)
        self.indel_shape = (latent_seq_length+1, 3, 2)

        # Priors.
        assert isinstance(prior_scale, float)
        self.prior_scale = torch.tensor(prior_scale)
        assert isinstance(indel_prior_strength, float)
        self.indel_prior = torch.tensor([indel_prior_strength, 0.])
        assert isinstance(inverse_temp_prior, float)
        self.inverse_temp_prior = torch.tensor(inverse_temp_prior)

        # Batch control.
        self.scale_factor = scale_factor

        # Initialize layers.
        self.encoder = Encoder(obs_seq_length, alphabet_length, z_dim)
        self.decoder = Decoder(latent_seq_length, alphabet_length, z_dim)
        self.statearrange = profile(latent_seq_length)

    def model(self, data):

        pyro.module("decoder", self.decoder)

        # Indel probabilities.
        insert = pyro.sample("insert", dist.Normal(
            self.indel_prior * torch.ones(self.indel_shape),
            self.prior_scale * torch.ones(self.indel_shape)).to_event(3))
        insert_logits = insert - insert.logsumexp(-1, True)
        delete = pyro.sample("delete", dist.Normal(
            self.indel_prior * torch.ones(self.indel_shape),
            self.prior_scale * torch.ones(self.indel_shape)).to_event(3))
        delete_logits = delete - delete.logsumexp(-1, True)

        # Inverse temperature.
        inverse_temp = pyro.sample("inverse_temp", dist.Normal(
            self.inverse_temp_prior, torch.tensor(1.)))

        with pyro.plate("batch", data.shape[0]), poutine.scale(
                scale=self.scale_factor):
            # Sample latent variable from prior.
            z = pyro.sample("latent", dist.Normal(
                torch.zeros(self.z_dim), torch.ones(self.z_dim)).to_event(1))
            # Decode latent sequence.
            latent_seq = self.decoder(z)
            # Construct ancestral and insertion sequences.
            ancestor_seq_logits = (latent_seq[..., 0, :, :] *
                                   softplus(inverse_temp))
            ancestor_seq_logits = (ancestor_seq_logits -
                                   ancestor_seq_logits.logsumexp(-1, True))
            insert_seq_logits = (latent_seq[..., 1, :, :] *
                                 softplus(inverse_temp))
            insert_seq_logits = (insert_seq_logits -
                                 insert_seq_logits.logsumexp(-1, True))
            # Construct HMM parameters.
            initial_logits, transition_logits, observation_logits = (
                self.statearrange(ancestor_seq_logits, insert_seq_logits,
                                  insert_logits, delete_logits))
            # Draw samples. (In this example, observations are one-hot
            # encoded and padded to a common length with all-zero rows, as
            # constructed in main below.)
            pyro.sample("obs",
                        VariableLengthDiscreteHMM(initial_logits,
                                                  transition_logits,
                                                  observation_logits),
                        obs=data)

    def guide(self, data):
        # Register encoder with pyro.
        pyro.module("encoder", self.encoder)

        # Indel probabilities.
        insert_q_mn = pyro.param("insert_q_mn",
                                 torch.ones(self.indel_shape)
                                 * self.indel_prior)
        insert_q_sd = pyro.param("insert_q_sd",
                                 torch.zeros(self.indel_shape))
        pyro.sample("insert", dist.Normal(
            insert_q_mn, softplus(insert_q_sd)).to_event(3))
        delete_q_mn = pyro.param("delete_q_mn",
                                 torch.ones(self.indel_shape)
                                 * self.indel_prior)
        delete_q_sd = pyro.param("delete_q_sd",
                                 torch.zeros(self.indel_shape))
        pyro.sample("delete", dist.Normal(
            delete_q_mn, softplus(delete_q_sd)).to_event(3))
        # Inverse temperature.
        inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.))
        inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.))
        pyro.sample("inverse_temp", dist.Normal(
            inverse_temp_q_mn, softplus(inverse_temp_q_sd)))

        # Per data latent variables.
        with pyro.plate("batch", data.shape[0]), poutine.scale(
                scale=self.scale_factor):
            # Encode seq.
            z_loc, z_scale = self.encoder(data)
            # Sample.
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def reconstruct_ancestor_seq(self, data, inverse_temp=1.):
        # Accept a plain float for inverse_temp (softplus requires a tensor);
        # torch.as_tensor leaves tensor arguments unchanged.
        inverse_temp = torch.as_tensor(inverse_temp)
        # Encode seq.
        z_loc = self.encoder(data)[0]
        # Reconstruct.
        latent_seq = self.decoder(z_loc)
        # Construct ancestral sequence.
        ancestor_seq_logits = latent_seq[..., 0, :, :] * softplus(inverse_temp)
        ancestor_seq_logits = (ancestor_seq_logits -
                               ancestor_seq_logits.logsumexp(-1, True))
        return torch.exp(ancestor_seq_logits)
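
# Minimal usage sketch (mirroring main() below; the sizes come from the
# example dataset constructed there):
#
#   model = FactorMuE(obs_seq_length=6, alphabet_length=2, z_dim=2)
#   svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO())
#   loss = svi.step(data)  # data: (num_seqs, obs_seq_length, alphabet_length)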


def main(args):

    torch.manual_seed(9)
    torch.set_default_tensor_type('torch.DoubleTensor')

    small_test = args.test

    if small_test:
        mult_dat = 1
        mult_step = 1
    else:
        mult_dat = 10
        mult_step = 400

    # Construct example dataset.
    xs = [torch.tensor([[0., 1.],
                        [1., 0.],
                        [0., 1.],
                        [0., 1.],
                        [1., 0.],
                        [0., 0.]]),
          torch.tensor([[0., 1.],
                        [1., 0.],
                        [1., 0.],
                        [0., 1.],
                        [0., 0.],
                        [0., 0.]]),
          torch.tensor([[0., 1.],
                        [1., 0.],
                        [0., 1.],
                        [0., 1.],
                        [0., 1.],
                        [0., 0.]])]
    data = torch.cat([xs[0][None, :, :] for j in range(6*mult_dat)] +
                     [xs[1][None, :, :] for j in range(4*mult_dat)] +
                     [xs[2][None, :, :] for j in range(4*mult_dat)], dim=0)
    # Set up inference.
    obs_seq_length, alphabet_length, z_dim = 6, 2, 2
    scheduler = MultiStepLR({'optimizer': Adam,
                             'optim_args': {'lr': 0.1},
                             'milestones': [20, 100, 1000, 2000],
                             'gamma': 0.5})
    model = FactorMuE(obs_seq_length, alphabet_length, z_dim)

    svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO())
    n_steps = 10*mult_step

    # Run inference.
    losses = []
    t0 = datetime.datetime.now()
    for step in range(n_steps):

        loss = svi.step(data)
        losses.append(loss)
        scheduler.step()
        if step % 10 == 0:
            print(step, loss, ' ', datetime.datetime.now() - t0)

    # Plots.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    plt.figure(figsize=(6, 6))
    plt.plot(losses)
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.savefig('FactorMuE_plot.loss_{}.pdf'.format(time_stamp))

    plt.figure(figsize=(6, 6))
    latent = model.encoder(data)[0].detach()
    plt.scatter(latent[:, 0], latent[:, 1])
    plt.xlabel('z_1')
    plt.ylabel('z_2')
    plt.savefig('FactorMuE_plot.latent_{}.pdf'.format(time_stamp))

    plt.figure(figsize=(6, 6))
    decoder_bias = pyro.param('decoder$$$f.bias').detach()
    decoder_bias = decoder_bias.reshape(
        [-1, 2, model.latent_seq_length+1, model.alphabet_length])
    plt.plot(decoder_bias[0, 0, :, 1])
    plt.xlabel('position')
    plt.ylabel('bias for character 1')
    plt.savefig('FactorMuE_plot.decoder_bias_{}.pdf'.format(time_stamp))

    for xi, x in enumerate(xs):
        reconstruct_x = model.reconstruct_ancestor_seq(
            x, pyro.param("inverse_temp_q_mn")).detach()
        plt.figure(figsize=(6, 6))
        plt.plot(reconstruct_x[0, :, 1], label="reconstruct")
        plt.plot(x[:, 1], label="data")
        plt.xlabel('position')
        plt.ylabel('probability of character 1')
        plt.legend()
        plt.savefig('FactorMuE_plot.reconstruction_{}_{}.pdf'.format(
            xi, time_stamp))

    plt.figure(figsize=(6, 6))
    insert = pyro.param("insert_q_mn").detach()
    insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
    plt.plot(insert_expect[:, :, 1].numpy())
    plt.xlabel('position')
    plt.ylabel('probability of insert')
    plt.savefig('FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))
    plt.figure(figsize=(6, 6))
    delete = pyro.param("delete_q_mn").detach()
    delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
    plt.plot(delete_expect[:, :, 1].numpy())
    plt.xlabel('position')
    plt.ylabel('probability of delete')
    plt.savefig('FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Basic Factor MuE model.")
    parser.add_argument('-t', '--test', action='store_true', default=False,
                        help='small dataset, a few steps')
    args = parser.parse_args()
    main(args)