-
-
Notifications
You must be signed in to change notification settings - Fork 984
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0c30148
commit 6c8d887
Showing
18 changed files
with
2,333 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ pyro/_version.py | |
processed | ||
raw | ||
*.pkl | ||
*.fasta | ||
baseline_net_q1.pth | ||
cvae_net_q1.pth | ||
cvae_plot_q1.png | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
Biological Sequence Models with MuE | ||
=================================== | ||
.. automodule:: pyro.contrib.mue | ||
|
||
.. warning:: Code in ``pyro.contrib.mue`` is under development. | ||
This code makes no guarantee about maintaining backwards compatibility. | ||
|
||
``pyro.contrib.mue`` provides modeling tools for working with biological | ||
sequence data. In particular it implements MuE distributions, which are used as | ||
a fully generative alternative to multiple sequence alignment-based | ||
preprocessing. | ||
|
||
Reference: | ||
MuE models were described in Weinstein and Marks (2021), | ||
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2. | ||
|
||
Example MuE Models | ||
------------------ | ||
.. automodule:: pyro.contrib.mue.models | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
|
||
State Arrangers for Parameterizing MuEs | ||
--------------------------------------- | ||
.. automodule:: pyro.contrib.mue.statearrangers | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
|
||
Missing or Variable Length Data HMM | ||
----------------------------------- | ||
.. automodule:: pyro.contrib.mue.missingdatahmm | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource | ||
|
||
|
||
Biosequence Dataset Loading | ||
--------------------------- | ||
.. automodule:: pyro.contrib.mue.dataloaders | ||
:members: | ||
:show-inheritance: | ||
:member-order: bysource |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
""" | ||
A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model | ||
[1]. This is a generative model of variable-length biological sequences (e.g. | ||
proteins) which does not require preprocessing the data by building a | ||
multiple sequence alignment. It can be used to infer a latent representation | ||
of sequences and the principal components of sequence variation, while | ||
accounting for alignment uncertainty. | ||
An example dataset consisting of proteins similar to the human papillomavirus E6 | ||
protein, collected from a non-redundant sequence dataset using jackhmmer, can | ||
be found at | ||
https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta | ||
Example run: | ||
python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25 | ||
--indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace | ||
--jit --cuda | ||
This should take about 8 minutes to run on a GPU. The latent space should show | ||
multiple small clusters, and the perplexity should be around 4. | ||
Reference: | ||
[1] E. N. Weinstein, D. S. Marks (2021) | ||
"Generative probabilistic biological sequence models that account for | ||
mutational variability" | ||
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf | ||
""" | ||
|
||
import argparse | ||
import datetime | ||
import json | ||
import os | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
from torch.optim import Adam | ||
|
||
import pyro | ||
from pyro.contrib.mue.dataloaders import BiosequenceDataset | ||
from pyro.contrib.mue.models import FactorMuE | ||
from pyro.optim import MultiStepLR | ||
|
||
|
||
def generate_data(small_test, include_stop, device): | ||
"""Generate mini example dataset.""" | ||
if small_test: | ||
mult_dat = 1 | ||
else: | ||
mult_dat = 10 | ||
|
||
seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat | ||
dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, | ||
device=device) | ||
|
||
return dataset | ||
|
||
|
||
def main(args): | ||
|
||
# Load dataset. | ||
if args.cpu_data and args.cuda: | ||
device = torch.device('cpu') | ||
else: | ||
device = None | ||
if args.test: | ||
dataset = generate_data(args.small, args.include_stop, device) | ||
else: | ||
dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, | ||
include_stop=args.include_stop, | ||
device=device) | ||
args.batch_size = min([dataset.data_size, args.batch_size]) | ||
if args.split > 0.: | ||
# Train test split. | ||
heldout_num = int(np.ceil(args.split*len(dataset))) | ||
data_lengths = [len(dataset) - heldout_num, heldout_num] | ||
# Specific data split seed, for comparability across models and | ||
# parameter initializations. | ||
pyro.set_rng_seed(args.rng_data_seed) | ||
indices = torch.randperm(sum(data_lengths)).tolist() | ||
dataset_train, dataset_test = [ | ||
torch.utils.data.Subset(dataset, indices[(offset - length):offset]) | ||
for offset, length in zip(torch._utils._accumulate(data_lengths), | ||
data_lengths)] | ||
else: | ||
dataset_train = dataset | ||
dataset_test = None | ||
|
||
# Training seed. | ||
pyro.set_rng_seed(args.rng_seed) | ||
|
||
# Construct model. | ||
model = FactorMuE(dataset.max_length, dataset.alphabet_length, | ||
args.z_dim, | ||
batch_size=args.batch_size, | ||
latent_seq_length=args.latent_seq_length, | ||
indel_factor_dependence=args.indel_factor, | ||
indel_prior_scale=args.indel_prior_scale, | ||
indel_prior_bias=args.indel_prior_bias, | ||
inverse_temp_prior=args.inverse_temp_prior, | ||
weights_prior_scale=args.weights_prior_scale, | ||
offset_prior_scale=args.offset_prior_scale, | ||
z_prior_distribution=args.z_prior, | ||
ARD_prior=args.ARD_prior, | ||
substitution_matrix=(not args.no_substitution_matrix), | ||
substitution_prior_scale=args.substitution_prior_scale, | ||
latent_alphabet_length=args.latent_alphabet, | ||
cuda=args.cuda, | ||
pin_memory=args.pin_mem) | ||
|
||
# Infer with SVI. | ||
scheduler = MultiStepLR({'optimizer': Adam, | ||
'optim_args': {'lr': args.learning_rate}, | ||
'milestones': json.loads(args.milestones), | ||
'gamma': args.learning_gamma}) | ||
n_epochs = args.n_epochs | ||
losses = model.fit_svi(dataset_train, n_epochs, args.anneal, | ||
args.batch_size, scheduler, args.jit) | ||
|
||
# Evaluate. | ||
train_lp, test_lp, train_perplex, test_perplex = model.evaluate( | ||
dataset_train, dataset_test, args.jit) | ||
print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) | ||
print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) | ||
|
||
# Get latent space embedding. | ||
z_locs, z_scales = model.embed(dataset) | ||
|
||
# Plot and save. | ||
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") | ||
if not args.no_plots: | ||
plt.figure(figsize=(6, 6)) | ||
plt.plot(losses) | ||
plt.xlabel('step') | ||
plt.ylabel('loss') | ||
if not args.no_save: | ||
plt.savefig(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) | ||
|
||
plt.figure(figsize=(6, 6)) | ||
plt.scatter(z_locs[:, 0], z_locs[:, 1]) | ||
plt.xlabel(r'$z_1$') | ||
plt.ylabel(r'$z_2$') | ||
if not args.no_save: | ||
plt.savefig(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) | ||
|
||
if not args.indel_factor: | ||
# Plot indel parameters. See statearrangers.py for details on the | ||
# r and u parameters. | ||
plt.figure(figsize=(6, 6)) | ||
insert = pyro.param("insert_q_mn").detach() | ||
insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) | ||
plt.plot(insert_expect[:, :, 1].cpu().numpy()) | ||
plt.xlabel('position') | ||
plt.ylabel('probability of insert') | ||
plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) | ||
if not args.no_save: | ||
plt.savefig(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) | ||
plt.figure(figsize=(6, 6)) | ||
delete = pyro.param("delete_q_mn").detach() | ||
delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) | ||
plt.plot(delete_expect[:, :, 1].cpu().numpy()) | ||
plt.xlabel('position') | ||
plt.ylabel('probability of delete') | ||
plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) | ||
if not args.no_save: | ||
plt.savefig(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) | ||
|
||
if not args.no_save: | ||
pyro.get_param_store().save(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_results.params_{}.out'.format(time_stamp))) | ||
with open(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)), | ||
'w') as ow: | ||
ow.write('train_lp,test_lp,train_perplex,test_perplex\n') | ||
ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, | ||
test_perplex)) | ||
np.savetxt(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_results.embed_loc_{}.txt'.format( | ||
time_stamp)), | ||
z_locs.cpu().numpy()) | ||
np.savetxt(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_results.embed_scale_{}.txt'.format( | ||
time_stamp)), | ||
z_scales.cpu().numpy()) | ||
with open(os.path.join( | ||
args.out_folder, | ||
'FactorMuE_results.input_{}.txt'.format(time_stamp)), | ||
'w') as ow: | ||
ow.write('[args]\n') | ||
for elem in list(args.__dict__.keys()): | ||
ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) | ||
|
||
|
||
if __name__ == '__main__': | ||
# Parse command line arguments. | ||
parser = argparse.ArgumentParser(description="Factor MuE model.") | ||
parser.add_argument("--test", action='store_true', default=False, | ||
help='Run with generated example dataset.') | ||
parser.add_argument("--small", action='store_true', default=False, | ||
help='Run with small example dataset.') | ||
parser.add_argument("-r", "--rng-seed", default=0, type=int) | ||
parser.add_argument("--rng-data-seed", default=0, type=int) | ||
parser.add_argument("-f", "--file", default=None, type=str, | ||
help='Input file (fasta format).') | ||
parser.add_argument("-a", "--alphabet", default='amino-acid', | ||
help='Alphabet (amino-acid OR dna OR ATGC ...).') | ||
parser.add_argument("-zdim", "--z-dim", default=2, type=int, | ||
help='z space dimension.') | ||
parser.add_argument("-b", "--batch-size", default=10, type=int, | ||
help='Batch size.') | ||
parser.add_argument("-M", "--latent-seq-length", default=None, type=int, | ||
help='Latent sequence length.') | ||
parser.add_argument("-idfac", "--indel-factor", default=False, | ||
action='store_true', | ||
help='Indel parameters depend on latent variable.') | ||
parser.add_argument("-zdist", "--z-prior", default='Normal', | ||
help='Latent prior distribution (normal or Laplace).') | ||
parser.add_argument("-ard", "--ARD-prior", default=False, | ||
action='store_true', | ||
help='Use automatic relevance detection prior.') | ||
parser.add_argument("--no-substitution-matrix", default=False, | ||
action='store_true', | ||
help='Do not use substitution matrix.') | ||
parser.add_argument("-D", "--latent-alphabet", default=None, type=int, | ||
help='Latent alphabet length.') | ||
parser.add_argument("--include-stop", default=False, action='store_true', | ||
help='Include stop symbol at the end of each sequence.') | ||
parser.add_argument("--indel-prior-scale", default=1., type=float, | ||
help=('Indel prior scale parameter ' + | ||
'(when indel-factor=False).')) | ||
parser.add_argument("--indel-prior-bias", default=10., type=float, | ||
help='Indel prior bias parameter.') | ||
parser.add_argument("--inverse-temp-prior", default=100., type=float, | ||
help='Inverse temperature prior mean.') | ||
parser.add_argument("--weights-prior-scale", default=1., type=float, | ||
help='Factor parameter prior scale.') | ||
parser.add_argument("--offset-prior-scale", default=1., type=float, | ||
help='Offset parameter prior scale.') | ||
parser.add_argument("--substitution-prior-scale", default=10., type=float, | ||
help='Substitution matrix prior scale.') | ||
parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, | ||
help='Learning rate for Adam optimizer.') | ||
parser.add_argument("--milestones", default='[]', type=str, | ||
help='Milestones for multistage learning rate.') | ||
parser.add_argument("--learning-gamma", default=0.5, type=float, | ||
help='Gamma parameter for multistage learning rate.') | ||
parser.add_argument("-e", "--n-epochs", default=10, type=int, | ||
help='Number of epochs of training.') | ||
parser.add_argument("--anneal", default=0., type=float, | ||
help='Number of epochs to anneal beta over.') | ||
parser.add_argument("--no-plots", default=False, action='store_true', | ||
help='Make plots.') | ||
parser.add_argument("--no-save", default=False, action='store_true', | ||
help='Do not save plots and results.') | ||
parser.add_argument("-outf", "--out-folder", default='.', | ||
help='Folder to save plots.') | ||
parser.add_argument("--split", default=0.2, type=float, | ||
help=('Fraction of dataset to holdout for testing')) | ||
parser.add_argument("--jit", default=False, action='store_true', | ||
help='JIT compile the ELBO.') | ||
parser.add_argument("--cuda", default=False, action='store_true', | ||
help='Use GPU.') | ||
parser.add_argument("--cpu-data", default=False, action='store_true', | ||
help='Keep data on CPU (for large datasets).') | ||
parser.add_argument("--pin-mem", default=False, action='store_true', | ||
help='Use pin_memory for faster CPU to GPU transfer.') | ||
args = parser.parse_args() | ||
|
||
if args.cuda: | ||
torch.set_default_tensor_type(torch.cuda.DoubleTensor) | ||
else: | ||
torch.set_default_dtype(torch.float64) | ||
|
||
main(args) |
Oops, something went wrong.