Skip to content

Commit

Permalink
MuE distributions for Pyro. (#2728)
Browse files Browse the repository at this point in the history
  • Loading branch information
EWeinstein committed Mar 23, 2021
1 parent 0c30148 commit 6c8d887
Show file tree
Hide file tree
Showing 18 changed files with 2,333 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pyro/_version.py
processed
raw
*.pkl
*.fasta
baseline_net_q1.pth
cvae_net_q1.pth
cvae_plot_q1.png
Expand Down
44 changes: 44 additions & 0 deletions docs/source/contrib.mue.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
Biological Sequence Models with MuE
===================================
.. automodule:: pyro.contrib.mue

.. warning:: Code in ``pyro.contrib.mue`` is under development.
This code makes no guarantee about maintaining backwards compatibility.

``pyro.contrib.mue`` provides modeling tools for working with biological
sequence data. In particular it implements MuE distributions, which are used as
a fully generative alternative to multiple sequence alignment-based
preprocessing.

Reference:
MuE models were described in Weinstein and Marks (2021),
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.

Example MuE Models
------------------
.. automodule:: pyro.contrib.mue.models
:members:
:show-inheritance:
:member-order: bysource

State Arrangers for Parameterizing MuEs
---------------------------------------
.. automodule:: pyro.contrib.mue.statearrangers
:members:
:show-inheritance:
:member-order: bysource

Missing or Variable Length Data HMM
-----------------------------------
.. automodule:: pyro.contrib.mue.missingdatahmm
:members:
:show-inheritance:
:member-order: bysource


Biosequence Dataset Loading
---------------------------
.. automodule:: pyro.contrib.mue.dataloaders
:members:
:show-inheritance:
:member-order: bysource
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Pyro Documentation
:caption: Pyro Core:

getting_started
primitives
primitives
inference
distributions
parameters
Expand All @@ -39,6 +39,7 @@ Pyro Documentation
contrib.funsor
contrib.gp
contrib.minipyro
contrib.mue
contrib.oed
contrib.randomvariable
contrib.timeseries
Expand All @@ -52,4 +53,3 @@ Indices and tables
* :ref:`search`

.. * :ref:`modindex`
288 changes: 288 additions & 0 deletions examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model
[1]. This is a generative model of variable-length biological sequences (e.g.
proteins) which does not require preprocessing the data by building a
multiple sequence alignment. It can be used to infer a latent representation
of sequences and the principal components of sequence variation, while
accounting for alignment uncertainty.
An example dataset consisting of proteins similar to the human papillomavirus E6
protein, collected from a non-redundant sequence dataset using jackhmmer, can
be found at
https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta
Example run:
python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25
--indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace
--jit --cuda
This should take about 8 minutes to run on a GPU. The latent space should show
multiple small clusters, and the perplexity should be around 4.
Reference:
[1] E. N. Weinstein, D. S. Marks (2021)
"Generative probabilistic biological sequence models that account for
mutational variability"
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
"""

import argparse
import datetime
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim import Adam

import pyro
from pyro.contrib.mue.dataloaders import BiosequenceDataset
from pyro.contrib.mue.models import FactorMuE
from pyro.optim import MultiStepLR


def generate_data(small_test, include_stop, device):
"""Generate mini example dataset."""
if small_test:
mult_dat = 1
else:
mult_dat = 10

seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat
dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop,
device=device)

return dataset


def main(args):

# Load dataset.
if args.cpu_data and args.cuda:
device = torch.device('cpu')
else:
device = None
if args.test:
dataset = generate_data(args.small, args.include_stop, device)
else:
dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet,
include_stop=args.include_stop,
device=device)
args.batch_size = min([dataset.data_size, args.batch_size])
if args.split > 0.:
# Train test split.
heldout_num = int(np.ceil(args.split*len(dataset)))
data_lengths = [len(dataset) - heldout_num, heldout_num]
# Specific data split seed, for comparability across models and
# parameter initializations.
pyro.set_rng_seed(args.rng_data_seed)
indices = torch.randperm(sum(data_lengths)).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length):offset])
for offset, length in zip(torch._utils._accumulate(data_lengths),
data_lengths)]
else:
dataset_train = dataset
dataset_test = None

# Training seed.
pyro.set_rng_seed(args.rng_seed)

# Construct model.
model = FactorMuE(dataset.max_length, dataset.alphabet_length,
args.z_dim,
batch_size=args.batch_size,
latent_seq_length=args.latent_seq_length,
indel_factor_dependence=args.indel_factor,
indel_prior_scale=args.indel_prior_scale,
indel_prior_bias=args.indel_prior_bias,
inverse_temp_prior=args.inverse_temp_prior,
weights_prior_scale=args.weights_prior_scale,
offset_prior_scale=args.offset_prior_scale,
z_prior_distribution=args.z_prior,
ARD_prior=args.ARD_prior,
substitution_matrix=(not args.no_substitution_matrix),
substitution_prior_scale=args.substitution_prior_scale,
latent_alphabet_length=args.latent_alphabet,
cuda=args.cuda,
pin_memory=args.pin_mem)

# Infer with SVI.
scheduler = MultiStepLR({'optimizer': Adam,
'optim_args': {'lr': args.learning_rate},
'milestones': json.loads(args.milestones),
'gamma': args.learning_gamma})
n_epochs = args.n_epochs
losses = model.fit_svi(dataset_train, n_epochs, args.anneal,
args.batch_size, scheduler, args.jit)

# Evaluate.
train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
dataset_train, dataset_test, args.jit)
print('train logp: {} perplex: {}'.format(train_lp, train_perplex))
print('test logp: {} perplex: {}'.format(test_lp, test_perplex))

# Get latent space embedding.
z_locs, z_scales = model.embed(dataset)

# Plot and save.
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
if not args.no_plots:
plt.figure(figsize=(6, 6))
plt.plot(losses)
plt.xlabel('step')
plt.ylabel('loss')
if not args.no_save:
plt.savefig(os.path.join(
args.out_folder,
'FactorMuE_plot.loss_{}.pdf'.format(time_stamp)))

plt.figure(figsize=(6, 6))
plt.scatter(z_locs[:, 0], z_locs[:, 1])
plt.xlabel(r'$z_1$')
plt.ylabel(r'$z_2$')
if not args.no_save:
plt.savefig(os.path.join(
args.out_folder,
'FactorMuE_plot.latent_{}.pdf'.format(time_stamp)))

if not args.indel_factor:
# Plot indel parameters. See statearrangers.py for details on the
# r and u parameters.
plt.figure(figsize=(6, 6))
insert = pyro.param("insert_q_mn").detach()
insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
plt.plot(insert_expect[:, :, 1].cpu().numpy())
plt.xlabel('position')
plt.ylabel('probability of insert')
plt.legend([r'$r_0$', r'$r_1$', r'$r_2$'])
if not args.no_save:
plt.savefig(os.path.join(
args.out_folder,
'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp)))
plt.figure(figsize=(6, 6))
delete = pyro.param("delete_q_mn").detach()
delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
plt.plot(delete_expect[:, :, 1].cpu().numpy())
plt.xlabel('position')
plt.ylabel('probability of delete')
plt.legend([r'$u_0$', r'$u_1$', r'$u_2$'])
if not args.no_save:
plt.savefig(os.path.join(
args.out_folder,
'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp)))

if not args.no_save:
pyro.get_param_store().save(os.path.join(
args.out_folder,
'FactorMuE_results.params_{}.out'.format(time_stamp)))
with open(os.path.join(
args.out_folder,
'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)),
'w') as ow:
ow.write('train_lp,test_lp,train_perplex,test_perplex\n')
ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex,
test_perplex))
np.savetxt(os.path.join(
args.out_folder,
'FactorMuE_results.embed_loc_{}.txt'.format(
time_stamp)),
z_locs.cpu().numpy())
np.savetxt(os.path.join(
args.out_folder,
'FactorMuE_results.embed_scale_{}.txt'.format(
time_stamp)),
z_scales.cpu().numpy())
with open(os.path.join(
args.out_folder,
'FactorMuE_results.input_{}.txt'.format(time_stamp)),
'w') as ow:
ow.write('[args]\n')
for elem in list(args.__dict__.keys()):
ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem)))


if __name__ == '__main__':
# Parse command line arguments.
parser = argparse.ArgumentParser(description="Factor MuE model.")
parser.add_argument("--test", action='store_true', default=False,
help='Run with generated example dataset.')
parser.add_argument("--small", action='store_true', default=False,
help='Run with small example dataset.')
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("--rng-data-seed", default=0, type=int)
parser.add_argument("-f", "--file", default=None, type=str,
help='Input file (fasta format).')
parser.add_argument("-a", "--alphabet", default='amino-acid',
help='Alphabet (amino-acid OR dna OR ATGC ...).')
parser.add_argument("-zdim", "--z-dim", default=2, type=int,
help='z space dimension.')
parser.add_argument("-b", "--batch-size", default=10, type=int,
help='Batch size.')
parser.add_argument("-M", "--latent-seq-length", default=None, type=int,
help='Latent sequence length.')
parser.add_argument("-idfac", "--indel-factor", default=False,
action='store_true',
help='Indel parameters depend on latent variable.')
parser.add_argument("-zdist", "--z-prior", default='Normal',
help='Latent prior distribution (normal or Laplace).')
parser.add_argument("-ard", "--ARD-prior", default=False,
action='store_true',
help='Use automatic relevance detection prior.')
parser.add_argument("--no-substitution-matrix", default=False,
action='store_true',
help='Do not use substitution matrix.')
parser.add_argument("-D", "--latent-alphabet", default=None, type=int,
help='Latent alphabet length.')
parser.add_argument("--include-stop", default=False, action='store_true',
help='Include stop symbol at the end of each sequence.')
parser.add_argument("--indel-prior-scale", default=1., type=float,
help=('Indel prior scale parameter ' +
'(when indel-factor=False).'))
parser.add_argument("--indel-prior-bias", default=10., type=float,
help='Indel prior bias parameter.')
parser.add_argument("--inverse-temp-prior", default=100., type=float,
help='Inverse temperature prior mean.')
parser.add_argument("--weights-prior-scale", default=1., type=float,
help='Factor parameter prior scale.')
parser.add_argument("--offset-prior-scale", default=1., type=float,
help='Offset parameter prior scale.')
parser.add_argument("--substitution-prior-scale", default=10., type=float,
help='Substitution matrix prior scale.')
parser.add_argument("-lr", "--learning-rate", default=0.001, type=float,
help='Learning rate for Adam optimizer.')
parser.add_argument("--milestones", default='[]', type=str,
help='Milestones for multistage learning rate.')
parser.add_argument("--learning-gamma", default=0.5, type=float,
help='Gamma parameter for multistage learning rate.')
parser.add_argument("-e", "--n-epochs", default=10, type=int,
help='Number of epochs of training.')
parser.add_argument("--anneal", default=0., type=float,
help='Number of epochs to anneal beta over.')
parser.add_argument("--no-plots", default=False, action='store_true',
help='Make plots.')
parser.add_argument("--no-save", default=False, action='store_true',
help='Do not save plots and results.')
parser.add_argument("-outf", "--out-folder", default='.',
help='Folder to save plots.')
parser.add_argument("--split", default=0.2, type=float,
help=('Fraction of dataset to holdout for testing'))
parser.add_argument("--jit", default=False, action='store_true',
help='JIT compile the ELBO.')
parser.add_argument("--cuda", default=False, action='store_true',
help='Use GPU.')
parser.add_argument("--cpu-data", default=False, action='store_true',
help='Keep data on CPU (for large datasets).')
parser.add_argument("--pin-mem", default=False, action='store_true',
help='Use pin_memory for faster CPU to GPU transfer.')
args = parser.parse_args()

if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)

main(args)
Loading

0 comments on commit 6c8d887

Please sign in to comment.