MuE distributions for Pyro. #2728

Merged
merged 93 commits into from
Mar 23, 2021
Commits
acbd882
Initialize main project files.
EWeinstein Dec 28, 2020
cbf6102
Module for converting MuE parameters (in particular, profile HMM) to …
EWeinstein Dec 28, 2020
1fe82b1
Test state arranger.
EWeinstein Dec 28, 2020
6e2868b
Debug state arranger tests.
EWeinstein Dec 29, 2020
7ddecb6
Variable length discrete hmm class and log probability function.
EWeinstein Dec 30, 2020
c2458f4
Test for log probability of variable length hmm.
EWeinstein Dec 30, 2020
1ea5af6
Simple profile hmm example model.
EWeinstein Dec 30, 2020
2f72384
Switch to standard hmm log probability convention.
EWeinstein Dec 30, 2020
1725392
FactorMuE forward pass.
EWeinstein Dec 31, 2020
bfba6ae
FactorMuE plots, debug.
EWeinstein Dec 31, 2020
c017cf5
Multistage training.
EWeinstein Dec 31, 2020
0ae6070
Cleanup.
EWeinstein Dec 31, 2020
112e86a
Add parser, adjust plot saving to avoid creating new folders.
EWeinstein Dec 31, 2020
f80100f
Cleanup imports.
EWeinstein Dec 31, 2020
a11aef0
More extensive testing for VariableLengthHMM.
EWeinstein Jan 14, 2021
81e0d49
Docs for mue.
EWeinstein Jan 15, 2021
e18b033
Shape tests and trivial case tests for profile statearranger.
EWeinstein Jan 15, 2021
a6ad405
Cleaning up naming conventions and doc string conventions.
EWeinstein Jan 16, 2021
d74a6f9
Docstrings with parameter details.
EWeinstein Jan 25, 2021
827721a
Improve indexing conventions, add unit tests.
EWeinstein Feb 8, 2021
e973c23
Rename files and build complete set of options for FactorMuE
EWeinstein Feb 10, 2021
c98ece0
Cleanup.
EWeinstein Feb 10, 2021
fa90df4
Basic dataloaders, and start rearranging training to be part of model.
EWeinstein Feb 11, 2021
0d7ea6c
subsampling inference provided with model
EWeinstein Feb 11, 2021
8cb2642
FactorMuE test.
EWeinstein Feb 12, 2021
7977022
Inference, length model, tests for profile hmm.
EWeinstein Feb 19, 2021
5d085cf
FactorMuE example full input options.
EWeinstein Feb 19, 2021
921eebe
Debug saving.
EWeinstein Feb 20, 2021
f07ce6d
Clean up
EWeinstein Feb 20, 2021
f982316
Add FactorMuE to test_examples.py
EWeinstein Feb 20, 2021
42e97bc
Files autochanged by make format.
EWeinstein Feb 20, 2021
cca23e1
Revert "Add FactorMuE to test_examples.py"
EWeinstein Feb 20, 2021
9132ae5
Revert "Files autochanged by make format."
EWeinstein Feb 20, 2021
f696c16
Add back in test examples lines.
EWeinstein Feb 20, 2021
cb31d2d
put in constraints
EWeinstein Feb 20, 2021
5c9500a
Merge remote-tracking branch 'upstream/dev' into mue
EWeinstein Feb 20, 2021
dd1a383
Git ignore .fasta files.
EWeinstein Feb 20, 2021
d2c513c
Make format's automated changes.
EWeinstein Feb 20, 2021
5c77f16
Revert "Make format's automated changes."
EWeinstein Feb 20, 2021
161f1e1
Adjust license headers.
EWeinstein Feb 20, 2021
1a86f9a
Profile and Factor example tests edited and added to main list.
EWeinstein Feb 20, 2021
ae861cc
adjust prior names
EWeinstein Feb 21, 2021
08f16cb
Rearrange profile HMM for jit compilation.
EWeinstein Feb 21, 2021
6c6a846
Debug jit compile ELBO in profile HMM
EWeinstein Feb 21, 2021
ede13d4
Reconfigure FactorMuE for jit compilation
EWeinstein Feb 21, 2021
1dcd4a8
Debug jit compilation for FactorMuE.
EWeinstein Feb 21, 2021
07acd22
Beta annealing in FactorMuE.
EWeinstein Feb 21, 2021
1506697
Evaluate train test elbo and perplexity for ProfileHMM
EWeinstein Feb 21, 2021
a4a620e
Heldout likelihood evaluation for factormue.
EWeinstein Feb 21, 2021
26061b7
switch to local elbo evaluation and add more options to factor example.
EWeinstein Feb 25, 2021
33730bd
Cuda option.
EWeinstein Feb 25, 2021
2304b0e
pin memory option
EWeinstein Feb 25, 2021
04e9a22
Fix data tensor initialization.
EWeinstein Feb 25, 2021
e07cca7
Move data to cuda.
EWeinstein Feb 26, 2021
fd52f5b
Transfer results back to cpu.
EWeinstein Feb 26, 2021
15e4b17
Move results back to cpu.
EWeinstein Feb 26, 2021
028e305
Move more results to cpu.
EWeinstein Feb 26, 2021
2311edc
Speed up initialization.
EWeinstein Feb 26, 2021
8b02abe
Move to device in generator.
EWeinstein Feb 26, 2021
a7b452b
Adjust cuda device transfer.
EWeinstein Feb 26, 2021
490be6c
Move data to device for now for ease.
EWeinstein Feb 26, 2021
c6e7f2d
Try another way of fixing cuda error.
EWeinstein Feb 26, 2021
f550447
Try disabling device transfer in dataloader entirely.
EWeinstein Feb 26, 2021
ef943cc
Disable device handling entirely.
EWeinstein Feb 26, 2021
ec9b216
Add back in device handling in random split
EWeinstein Feb 26, 2021
c247c99
Try moving data lengths to cuda?
EWeinstein Feb 26, 2021
50e2149
Try another combination of device calls.
EWeinstein Feb 26, 2021
f06a1b4
Try adjusting generator again.
EWeinstein Feb 26, 2021
a455ba8
Try removing generator statement entirely.
EWeinstein Feb 26, 2021
511b51e
Try removing data util random split call.
EWeinstein Feb 26, 2021
1301e97
Clean up comments.
EWeinstein Feb 26, 2021
690dae8
Fix embedding ordering.
EWeinstein Feb 27, 2021
999bbe2
Update profile HMM example.
EWeinstein Feb 27, 2021
1fcbd67
Update tests, improve alphabet handling.
EWeinstein Feb 27, 2021
aa2850d
Adjust cuda defaults.
EWeinstein Feb 27, 2021
23ddc45
Move back to cpu for plotting.
EWeinstein Feb 27, 2021
f70115e
Merge remote-tracking branch 'upstream/dev' into mue
EWeinstein Mar 15, 2021
ca345f0
Documentation for profile hmm model.
EWeinstein Mar 15, 2021
cdea31d
Docs for FactorMuE model.
EWeinstein Mar 16, 2021
d8c9546
Cleaned up docs.
EWeinstein Mar 16, 2021
8542820
Tutorials.
EWeinstein Mar 16, 2021
5199844
Add example run scripts.
EWeinstein Mar 16, 2021
0085283
Make format changes.
EWeinstein Mar 16, 2021
b708557
Fix example boolean inputs
EWeinstein Mar 16, 2021
a4fc2e7
Wording edit.
EWeinstein Mar 16, 2021
3089861
Add jit and cuda calls to test_examples.
EWeinstein Mar 17, 2021
aa388a3
Add stop codon handling
EWeinstein Mar 17, 2021
df10c7b
Remove old length modeling mechanism in favor of using stop symbols
EWeinstein Mar 17, 2021
767d9e0
Option to keep data on cpu
EWeinstein Mar 17, 2021
7617acb
check device of seq data
EWeinstein Mar 17, 2021
839adf5
CPU storage option for profile HMM
EWeinstein Mar 17, 2021
4dcf385
Update example tests
EWeinstein Mar 17, 2021
2e192a8
Addressed fritzo comments
EWeinstein Mar 22, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ pyro/_version.py
processed
raw
*.pkl
*.fasta
baseline_net_q1.pth
cvae_net_q1.pth
cvae_plot_q1.png
44 changes: 44 additions & 0 deletions docs/source/contrib.mue.rst
@@ -0,0 +1,44 @@
Biological Sequence Models with MuE
===================================
.. automodule:: pyro.contrib.mue

.. warning:: Code in ``pyro.contrib.mue`` is under development.
    This code makes no guarantee about maintaining backwards compatibility.

``pyro.contrib.mue`` provides modeling tools for working with biological
sequence data. In particular it implements MuE distributions, which are used as
a fully generative alternative to multiple sequence alignment-based
preprocessing.

Reference:
MuE models were described in Weinstein and Marks (2021),
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.

Example MuE Models
------------------
.. automodule:: pyro.contrib.mue.models
    :members:
    :show-inheritance:
    :member-order: bysource

State Arrangers for Parameterizing MuEs
---------------------------------------
.. automodule:: pyro.contrib.mue.statearrangers
    :members:
    :show-inheritance:
    :member-order: bysource

Missing or Variable Length Data HMM
-----------------------------------
.. automodule:: pyro.contrib.mue.missingdatahmm
    :members:
    :show-inheritance:
    :member-order: bysource


Biosequence Dataset Loading
---------------------------
.. automodule:: pyro.contrib.mue.dataloaders
    :members:
    :show-inheritance:
    :member-order: bysource
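
Editor's note on the "Missing or Variable Length Data HMM" section: the module body is not shown in this diff, so the following is only a minimal sketch of the general technique, the forward algorithm in log space, applied to sequences of different lengths. The names `logsumexp` and `hmm_log_prob`, the parameter layout, and the convention that the initial distribution applies to the state emitting the first symbol are all assumptions made for illustration. They are not the API or the convention of `pyro.contrib.mue.missingdatahmm`.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def hmm_log_prob(obs, log_init, log_trans, log_emit):
    """Log probability of one observed sequence under a discrete HMM.

    Hypothetical helper (not part of Pyro).
    obs: list of symbol indices, any length >= 1.
    log_init[k]: log p(first state = k).
    log_trans[j][k]: log p(next state = k | current state = j).
    log_emit[k][x]: log p(symbol = x | state = k).
    """
    n_states = len(log_init)
    # alpha[k] = log p(symbols so far, current state = k)
    alpha = [log_init[k] + log_emit[k][obs[0]] for k in range(n_states)]
    for x in obs[1:]:
        alpha = [
            logsumexp([alpha[j] + log_trans[j][k] for j in range(n_states)])
            + log_emit[k][x]
            for k in range(n_states)
        ]
    return logsumexp(alpha)


# Toy two-state HMM over the alphabet 'AB' (A -> 0, B -> 1).
log_init = [math.log(0.6), math.log(0.4)]
log_trans = [[math.log(0.7), math.log(0.3)],
             [math.log(0.2), math.log(0.8)]]
log_emit = [[math.log(0.9), math.log(0.1)],
            [math.log(0.3), math.log(0.7)]]

# Each sequence is scored at its own length; no padding is needed here.
for seq in ['BABBA', 'BAAB']:
    obs = ['AB'.index(c) for c in seq]
    print(seq, hmm_log_prob(obs, log_init, log_trans, log_emit))
```

A batched tensor implementation would typically pad sequences to a common length and mask the padded positions instead of looping per sequence. This sketch sidesteps that by handling one sequence at a time.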
4 changes: 2 additions & 2 deletions docs/source/index.rst
@@ -14,7 +14,7 @@ Pyro Documentation
:caption: Pyro Core:

getting_started
primitives
primitives
inference
distributions
parameters
@@ -39,6 +39,7 @@ Pyro Documentation
contrib.funsor
contrib.gp
contrib.minipyro
contrib.mue
contrib.oed
contrib.randomvariable
contrib.timeseries
@@ -52,4 +53,3 @@ Indices and tables
* :ref:`search`

.. * :ref:`modindex`

288 changes: 288 additions & 0 deletions examples/contrib/mue/FactorMuE.py
@@ -0,0 +1,288 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model
[1]. This is a generative model of variable-length biological sequences (e.g.
proteins) which does not require preprocessing the data by building a
multiple sequence alignment. It can be used to infer a latent representation
of sequences and the principal components of sequence variation, while
accounting for alignment uncertainty.

An example dataset consisting of proteins similar to the human papillomavirus E6
protein, collected from a non-redundant sequence dataset using jackhmmer, can
be found at
https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta

Example run:
    python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25
        --indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace
        --jit --cuda
This should take about 8 minutes to run on a GPU. The latent space should show
multiple small clusters, and the perplexity should be around 4.

Reference:
[1] E. N. Weinstein, D. S. Marks (2021)
"Generative probabilistic biological sequence models that account for
mutational variability"
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
"""

import argparse
import datetime
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim import Adam

import pyro
from pyro.contrib.mue.dataloaders import BiosequenceDataset
from pyro.contrib.mue.models import FactorMuE
from pyro.optim import MultiStepLR


def generate_data(small_test, include_stop, device):
    """Generate mini example dataset."""
    if small_test:
        mult_dat = 1
    else:
        mult_dat = 10

    seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat
    dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop,
                                 device=device)

    return dataset


def main(args):

    # Load dataset.
    if args.cpu_data and args.cuda:
        device = torch.device('cpu')
    else:
        device = None
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet,
                                     include_stop=args.include_stop,
                                     device=device)
    args.batch_size = min([dataset.data_size, args.batch_size])
    if args.split > 0.:
        # Train test split.
        heldout_num = int(np.ceil(args.split*len(dataset)))
        train_num = len(dataset) - heldout_num
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(len(dataset)).tolist()
        # Slice the permutation directly instead of relying on the private
        # torch._utils._accumulate helper; the resulting split is the same.
        dataset_train = torch.utils.data.Subset(dataset, indices[:train_num])
        dataset_test = torch.utils.data.Subset(dataset, indices[train_num:])
    else:
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    model = FactorMuE(dataset.max_length, dataset.alphabet_length,
                      args.z_dim,
                      batch_size=args.batch_size,
                      latent_seq_length=args.latent_seq_length,
                      indel_factor_dependence=args.indel_factor,
                      indel_prior_scale=args.indel_prior_scale,
                      indel_prior_bias=args.indel_prior_bias,
                      inverse_temp_prior=args.inverse_temp_prior,
                      weights_prior_scale=args.weights_prior_scale,
                      offset_prior_scale=args.offset_prior_scale,
                      z_prior_distribution=args.z_prior,
                      ARD_prior=args.ARD_prior,
                      substitution_matrix=(not args.no_substitution_matrix),
                      substitution_prior_scale=args.substitution_prior_scale,
                      latent_alphabet_length=args.latent_alphabet,
                      cuda=args.cuda,
                      pin_memory=args.pin_mem)

    # Infer with SVI.
    scheduler = MultiStepLR({'optimizer': Adam,
                             'optim_args': {'lr': args.learning_rate},
                             'milestones': json.loads(args.milestones),
                             'gamma': args.learning_gamma})
    n_epochs = args.n_epochs
    losses = model.fit_svi(dataset_train, n_epochs, args.anneal,
                           args.batch_size, scheduler, args.jit)

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit)
    print('train logp: {} perplex: {}'.format(train_lp, train_perplex))
    print('test logp: {} perplex: {}'.format(test_lp, test_perplex))

    # Get latent space embedding.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    if not args.no_plots:
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel('step')
        plt.ylabel('loss')
        if not args.no_save:
            plt.savefig(os.path.join(
                 args.out_folder,
                 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp)))

        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r'$z_1$')
        plt.ylabel(r'$z_2$')
        if not args.no_save:
            plt.savefig(os.path.join(
                 args.out_folder,
                 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp)))

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel('position')
            plt.ylabel('probability of insert')
            plt.legend([r'$r_0$', r'$r_1$', r'$r_2$'])
            if not args.no_save:
                plt.savefig(os.path.join(
                     args.out_folder,
                     'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp)))
            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel('position')
            plt.ylabel('probability of delete')
            plt.legend([r'$u_0$', r'$u_1$', r'$u_2$'])
            if not args.no_save:
                plt.savefig(os.path.join(
                     args.out_folder,
                     'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp)))

    if not args.no_save:
        pyro.get_param_store().save(os.path.join(
                args.out_folder,
                'FactorMuE_results.params_{}.out'.format(time_stamp)))
        with open(os.path.join(
                args.out_folder,
                'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)),
                'w') as ow:
            ow.write('train_lp,test_lp,train_perplex,test_perplex\n')
            ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex,
                                            test_perplex))
        np.savetxt(os.path.join(
                args.out_folder,
                'FactorMuE_results.embed_loc_{}.txt'.format(
                    time_stamp)),
                z_locs.cpu().numpy())
        np.savetxt(os.path.join(
                args.out_folder,
                'FactorMuE_results.embed_scale_{}.txt'.format(
                    time_stamp)),
                z_scales.cpu().numpy())
        with open(os.path.join(
                args.out_folder,
                'FactorMuE_results.input_{}.txt'.format(time_stamp)),
                'w') as ow:
            ow.write('[args]\n')
            # Record every parsed command line argument alongside the results.
            for key, value in vars(args).items():
                ow.write('{} = {}\n'.format(key, value))


if __name__ == '__main__':
    # Parse command line arguments.
    parser = argparse.ArgumentParser(description="Factor MuE model.")
    parser.add_argument("--test", action='store_true', default=False,
                        help='Run with generated example dataset.')
    parser.add_argument("--small", action='store_true', default=False,
                        help='Run with small example dataset.')
    parser.add_argument("-r", "--rng-seed", default=0, type=int)
    parser.add_argument("--rng-data-seed", default=0, type=int)
    parser.add_argument("-f", "--file", default=None, type=str,
                        help='Input file (fasta format).')
    parser.add_argument("-a", "--alphabet", default='amino-acid',
                        help='Alphabet (amino-acid OR dna OR ATGC ...).')
    parser.add_argument("-zdim", "--z-dim", default=2, type=int,
                        help='z space dimension.')
    parser.add_argument("-b", "--batch-size", default=10, type=int,
                        help='Batch size.')
    parser.add_argument("-M", "--latent-seq-length", default=None, type=int,
                        help='Latent sequence length.')
    parser.add_argument("-idfac", "--indel-factor", default=False,
                        action='store_true',
                        help='Indel parameters depend on latent variable.')
    parser.add_argument("-zdist", "--z-prior", default='Normal',
                        help='Latent prior distribution (Normal or Laplace).')
    parser.add_argument("-ard", "--ARD-prior", default=False,
                        action='store_true',
                        help='Use automatic relevance determination prior.')
    parser.add_argument("--no-substitution-matrix", default=False,
                        action='store_true',
                        help='Do not use substitution matrix.')
    parser.add_argument("-D", "--latent-alphabet", default=None, type=int,
                        help='Latent alphabet length.')
    parser.add_argument("--include-stop", default=False, action='store_true',
                        help='Include stop symbol at the end of each sequence.')
    parser.add_argument("--indel-prior-scale", default=1., type=float,
                        help=('Indel prior scale parameter ' +
                              '(when indel-factor=False).'))
    parser.add_argument("--indel-prior-bias", default=10., type=float,
                        help='Indel prior bias parameter.')
    parser.add_argument("--inverse-temp-prior", default=100., type=float,
                        help='Inverse temperature prior mean.')
    parser.add_argument("--weights-prior-scale", default=1., type=float,
                        help='Factor parameter prior scale.')
    parser.add_argument("--offset-prior-scale", default=1., type=float,
                        help='Offset parameter prior scale.')
    parser.add_argument("--substitution-prior-scale", default=10., type=float,
                        help='Substitution matrix prior scale.')
    parser.add_argument("-lr", "--learning-rate", default=0.001, type=float,
                        help='Learning rate for Adam optimizer.')
    parser.add_argument("--milestones", default='[]', type=str,
                        help='Milestones for multistage learning rate.')
    parser.add_argument("--learning-gamma", default=0.5, type=float,
                        help='Gamma parameter for multistage learning rate.')
    parser.add_argument("-e", "--n-epochs", default=10, type=int,
                        help='Number of epochs of training.')
    parser.add_argument("--anneal", default=0., type=float,
                        help='Number of epochs to anneal beta over.')
    parser.add_argument("--no-plots", default=False, action='store_true',
                        help='Do not make plots.')
    parser.add_argument("--no-save", default=False, action='store_true',
                        help='Do not save plots and results.')
    parser.add_argument("-outf", "--out-folder", default='.',
                        help='Folder to save plots and results.')
    parser.add_argument("--split", default=0.2, type=float,
                        help='Fraction of dataset to hold out for testing.')
    parser.add_argument("--jit", default=False, action='store_true',
                        help='JIT compile the ELBO.')
    parser.add_argument("--cuda", default=False, action='store_true',
                        help='Use GPU.')
    parser.add_argument("--cpu-data", default=False, action='store_true',
                        help='Keep data on CPU (for large datasets).')
    parser.add_argument("--pin-mem", default=False, action='store_true',
                        help='Use pin_memory for faster CPU to GPU transfer.')
    args = parser.parse_args()

    if args.cuda:
        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
    else:
        torch.set_default_dtype(torch.float64)

    main(args)
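
Editor's note on the perplexity figures printed by the script above: the docstring says the example run should reach a perplexity of around 4. The body of `model.evaluate` is not part of this diff, so the sketch below assumes the common per-symbol definition, the exponential of the negative log probability divided by the total number of symbols. The function name `perplexity` and the toy numbers are hypothetical.

```python
import math


def perplexity(total_log_prob, seq_lengths):
    """Per-symbol perplexity: exp(-total log probability / total symbols).

    Assumed definition for illustration; the normalization used by
    FactorMuE.evaluate is not shown in this diff.
    """
    total_symbols = sum(seq_lengths)
    return math.exp(-total_log_prob / total_symbols)


# A model that gives every symbol probability 1/4 has perplexity close to 4,
# i.e. it is as uncertain as a uniform choice among 4 symbols per position.
seq_lengths = [5, 4, 5]
uniform4_log_prob = sum(seq_lengths) * math.log(1 / 4)
print(perplexity(uniform4_log_prob, seq_lengths))

# For comparison, a uniform guess over the 20 amino acids has perplexity
# close to 20.
uniform20_log_prob = sum(seq_lengths) * math.log(1 / 20)
print(perplexity(uniform20_log_prob, seq_lengths))
```

Under this reading, a perplexity near 4 on protein data means the model narrows each position down to roughly 4 equally likely residues on average, compared with 20 for an uninformed guess.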