From 6c8d887efd861042c41cedd5cd86c39ab1885a94 Mon Sep 17 00:00:00 2001 From: Eli Weinstein Date: Tue, 23 Mar 2021 09:03:29 -0400 Subject: [PATCH] MuE distributions for Pyro. (#2728) --- .gitignore | 1 + docs/source/contrib.mue.rst | 44 ++ docs/source/index.rst | 4 +- examples/contrib/mue/FactorMuE.py | 288 +++++++++ examples/contrib/mue/ProfileHMM.py | 231 +++++++ pyro/contrib/mue/__init__.py | 0 pyro/contrib/mue/dataloaders.py | 112 ++++ pyro/contrib/mue/missingdatahmm.py | 104 ++++ pyro/contrib/mue/models.py | 731 +++++++++++++++++++++++ pyro/contrib/mue/statearrangers.py | 209 +++++++ tests/contrib/mue/test_dataloaders.py | 69 +++ tests/contrib/mue/test_missingdatahmm.py | 173 ++++++ tests/contrib/mue/test_models.py | 93 +++ tests/contrib/mue/test_statearrangers.py | 234 ++++++++ tests/test_examples.py | 12 + tutorial/source/index.rst | 8 + tutorial/source/mue_factor.rst | 11 + tutorial/source/mue_profile.rst | 11 + 18 files changed, 2333 insertions(+), 2 deletions(-) create mode 100644 docs/source/contrib.mue.rst create mode 100644 examples/contrib/mue/FactorMuE.py create mode 100644 examples/contrib/mue/ProfileHMM.py create mode 100644 pyro/contrib/mue/__init__.py create mode 100644 pyro/contrib/mue/dataloaders.py create mode 100644 pyro/contrib/mue/missingdatahmm.py create mode 100644 pyro/contrib/mue/models.py create mode 100644 pyro/contrib/mue/statearrangers.py create mode 100644 tests/contrib/mue/test_dataloaders.py create mode 100644 tests/contrib/mue/test_missingdatahmm.py create mode 100644 tests/contrib/mue/test_models.py create mode 100644 tests/contrib/mue/test_statearrangers.py create mode 100644 tutorial/source/mue_factor.rst create mode 100644 tutorial/source/mue_profile.rst diff --git a/.gitignore b/.gitignore index 2154a3af56..2a2745a94d 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ pyro/_version.py processed raw *.pkl +*.fasta baseline_net_q1.pth cvae_net_q1.pth cvae_plot_q1.png diff --git a/docs/source/contrib.mue.rst b/docs/source/contrib.mue.rst new file mode 100644 index 0000000000..116279c3c6 --- /dev/null +++ b/docs/source/contrib.mue.rst @@ -0,0 +1,44 @@ +Biological Sequence Models with MuE +=================================== +.. automodule:: pyro.contrib.mue + +.. warning:: Code in ``pyro.contrib.mue`` is under development. + This code makes no guarantee about maintaining backwards compatibility. + +``pyro.contrib.mue`` provides modeling tools for working with biological +sequence data. In particular it implements MuE distributions, which are used as +a fully generative alternative to multiple sequence alignment-based +preprocessing. + +Reference: +MuE models were described in Weinstein and Marks (2021), +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2. + +Example MuE Models +------------------ +.. automodule:: pyro.contrib.mue.models + :members: + :show-inheritance: + :member-order: bysource + +State Arrangers for Parameterizing MuEs +--------------------------------------- +.. automodule:: pyro.contrib.mue.statearrangers + :members: + :show-inheritance: + :member-order: bysource + +Missing or Variable Length Data HMM +----------------------------------- +.. automodule:: pyro.contrib.mue.missingdatahmm + :members: + :show-inheritance: + :member-order: bysource + + +Biosequence Dataset Loading +--------------------------- +.. 
automodule:: pyro.contrib.mue.dataloaders + :members: + :show-inheritance: + :member-order: bysource diff --git a/docs/source/index.rst b/docs/source/index.rst index 0d364c67bf..c5db42fdc6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,7 +14,7 @@ Pyro Documentation :caption: Pyro Core: getting_started - primitives + primitives inference distributions parameters @@ -39,6 +39,7 @@ Pyro Documentation contrib.funsor contrib.gp contrib.minipyro + contrib.mue contrib.oed contrib.randomvariable contrib.timeseries @@ -52,4 +53,3 @@ Indices and tables * :ref:`search` .. * :ref:`modindex` - diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py new file mode 100644 index 0000000000..a122be7266 --- /dev/null +++ b/examples/contrib/mue/FactorMuE.py @@ -0,0 +1,288 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model +[1]. This is a generative model of variable-length biological sequences (e.g. +proteins) which does not require preprocessing the data by building a +multiple sequence alignment. It can be used to infer a latent representation +of sequences and the principal components of sequence variation, while +accounting for alignment uncertainty. + +An example dataset consisting of proteins similar to the human papillomavirus E6 +protein, collected from a non-redundant sequence dataset using jackhmmer, can +be found at +https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta + +Example run: +python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25 + --indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace + --jit --cuda +This should take about 8 minutes to run on a GPU. The latent space should show +multiple small clusters, and the perplexity should be around 4. + +Reference: +[1] E. N. Weinstein, D. S. Marks (2021) +"Generative probabilistic biological sequence models that account for +mutational variability" +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf +""" + +import argparse +import datetime +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim import Adam + +import pyro +from pyro.contrib.mue.dataloaders import BiosequenceDataset +from pyro.contrib.mue.models import FactorMuE +from pyro.optim import MultiStepLR + + +def generate_data(small_test, include_stop, device): + """Generate mini example dataset.""" + if small_test: + mult_dat = 1 + else: + mult_dat = 10 + + seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, + device=device) + + return dataset + + +def main(args): + + # Load dataset. + if args.cpu_data and args.cuda: + device = torch.device('cpu') + else: + device = None + if args.test: + dataset = generate_data(args.small, args.include_stop, device) + else: + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, + include_stop=args.include_stop, + device=device) + args.batch_size = min([dataset.data_size, args.batch_size]) + if args.split > 0.: + # Train test split. + heldout_num = int(np.ceil(args.split*len(dataset))) + data_lengths = [len(dataset) - heldout_num, heldout_num] + # Specific data split seed, for comparability across models and + # parameter initializations. 
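+        # The permutation below is drawn under rng_data_seed so the same
+        # train/test partition is reused across runs; each Subset then takes a
+        # contiguous slice of the shuffled indices (train first, then test).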
+ pyro.set_rng_seed(args.rng_data_seed) + indices = torch.randperm(sum(data_lengths)).tolist() + dataset_train, dataset_test = [ + torch.utils.data.Subset(dataset, indices[(offset - length):offset]) + for offset, length in zip(torch._utils._accumulate(data_lengths), + data_lengths)] + else: + dataset_train = dataset + dataset_test = None + + # Training seed. + pyro.set_rng_seed(args.rng_seed) + + # Construct model. + model = FactorMuE(dataset.max_length, dataset.alphabet_length, + args.z_dim, + batch_size=args.batch_size, + latent_seq_length=args.latent_seq_length, + indel_factor_dependence=args.indel_factor, + indel_prior_scale=args.indel_prior_scale, + indel_prior_bias=args.indel_prior_bias, + inverse_temp_prior=args.inverse_temp_prior, + weights_prior_scale=args.weights_prior_scale, + offset_prior_scale=args.offset_prior_scale, + z_prior_distribution=args.z_prior, + ARD_prior=args.ARD_prior, + substitution_matrix=(not args.no_substitution_matrix), + substitution_prior_scale=args.substitution_prior_scale, + latent_alphabet_length=args.latent_alphabet, + cuda=args.cuda, + pin_memory=args.pin_mem) + + # Infer with SVI. + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': args.learning_rate}, + 'milestones': json.loads(args.milestones), + 'gamma': args.learning_gamma}) + n_epochs = args.n_epochs + losses = model.fit_svi(dataset_train, n_epochs, args.anneal, + args.batch_size, scheduler, args.jit) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset_train, dataset_test, args.jit) + print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) + print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) + + # Get latent space embedding. + z_locs, z_scales = model.embed(dataset) + + # Plot and save. + time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + if not args.no_plots: + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.loss_{}.pdf'.format(time_stamp))) + + plt.figure(figsize=(6, 6)) + plt.scatter(z_locs[:, 0], z_locs[:, 1]) + plt.xlabel(r'$z_1$') + plt.ylabel(r'$z_2$') + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.latent_{}.pdf'.format(time_stamp))) + + if not args.indel_factor: + # Plot indel parameters. See statearrangers.py for details on the + # r and u parameters. 
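+            # insert_q_mn / delete_q_mn are the variational means of the indel
+            # logits, shape (M, 3, 2); softmax over the last axis gives
+            # (no indel, indel) probabilities, so [:, :, 1] is the per-position
+            # indel probability for each of the three transition cases r_0, r_1, r_2.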
+ plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].cpu().numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].cpu().numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'FactorMuE_plot.delete_prob_{}.pdf'.format(time_stamp))) + + if not args.no_save: + pyro.get_param_store().save(os.path.join( + args.out_folder, + 'FactorMuE_results.params_{}.out'.format(time_stamp))) + with open(os.path.join( + args.out_folder, + 'FactorMuE_results.evaluation_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('train_lp,test_lp,train_perplex,test_perplex\n') + ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, + test_perplex)) + np.savetxt(os.path.join( + args.out_folder, + 'FactorMuE_results.embed_loc_{}.txt'.format( + time_stamp)), + z_locs.cpu().numpy()) + np.savetxt(os.path.join( + args.out_folder, + 'FactorMuE_results.embed_scale_{}.txt'.format( + time_stamp)), + z_scales.cpu().numpy()) + with open(os.path.join( + args.out_folder, + 'FactorMuE_results.input_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('[args]\n') + for elem in list(args.__dict__.keys()): + ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) + + +if __name__ == '__main__': + # Parse command line arguments. 
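+    # A quick smoke test on the built-in toy dataset (no fasta file required)
+    # can be run with, e.g.:
+    #   python FactorMuE.py --test --small -e 2 --no-plots --no-save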
+ parser = argparse.ArgumentParser(description="Factor MuE model.") + parser.add_argument("--test", action='store_true', default=False, + help='Run with generated example dataset.') + parser.add_argument("--small", action='store_true', default=False, + help='Run with small example dataset.') + parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("--rng-data-seed", default=0, type=int) + parser.add_argument("-f", "--file", default=None, type=str, + help='Input file (fasta format).') + parser.add_argument("-a", "--alphabet", default='amino-acid', + help='Alphabet (amino-acid OR dna OR ATGC ...).') + parser.add_argument("-zdim", "--z-dim", default=2, type=int, + help='z space dimension.') + parser.add_argument("-b", "--batch-size", default=10, type=int, + help='Batch size.') + parser.add_argument("-M", "--latent-seq-length", default=None, type=int, + help='Latent sequence length.') + parser.add_argument("-idfac", "--indel-factor", default=False, + action='store_true', + help='Indel parameters depend on latent variable.') + parser.add_argument("-zdist", "--z-prior", default='Normal', + help='Latent prior distribution (normal or Laplace).') + parser.add_argument("-ard", "--ARD-prior", default=False, + action='store_true', + help='Use automatic relevance detection prior.') + parser.add_argument("--no-substitution-matrix", default=False, + action='store_true', + help='Do not use substitution matrix.') + parser.add_argument("-D", "--latent-alphabet", default=None, type=int, + help='Latent alphabet length.') + parser.add_argument("--include-stop", default=False, action='store_true', + help='Include stop symbol at the end of each sequence.') + parser.add_argument("--indel-prior-scale", default=1., type=float, + help=('Indel prior scale parameter ' + + '(when indel-factor=False).')) + parser.add_argument("--indel-prior-bias", default=10., type=float, + help='Indel prior bias parameter.') + parser.add_argument("--inverse-temp-prior", default=100., type=float, + help='Inverse temperature prior mean.') + parser.add_argument("--weights-prior-scale", default=1., type=float, + help='Factor parameter prior scale.') + parser.add_argument("--offset-prior-scale", default=1., type=float, + help='Offset parameter prior scale.') + parser.add_argument("--substitution-prior-scale", default=10., type=float, + help='Substitution matrix prior scale.') + parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, + help='Learning rate for Adam optimizer.') + parser.add_argument("--milestones", default='[]', type=str, + help='Milestones for multistage learning rate.') + parser.add_argument("--learning-gamma", default=0.5, type=float, + help='Gamma parameter for multistage learning rate.') + parser.add_argument("-e", "--n-epochs", default=10, type=int, + help='Number of epochs of training.') + parser.add_argument("--anneal", default=0., type=float, + help='Number of epochs to anneal beta over.') + parser.add_argument("--no-plots", default=False, action='store_true', + help='Make plots.') + parser.add_argument("--no-save", default=False, action='store_true', + help='Do not save plots and results.') + parser.add_argument("-outf", "--out-folder", default='.', + help='Folder to save plots.') + parser.add_argument("--split", default=0.2, type=float, + help=('Fraction of dataset to holdout for testing')) + parser.add_argument("--jit", default=False, action='store_true', + help='JIT compile the ELBO.') + parser.add_argument("--cuda", default=False, action='store_true', + help='Use 
GPU.') + parser.add_argument("--cpu-data", default=False, action='store_true', + help='Keep data on CPU (for large datasets).') + parser.add_argument("--pin-mem", default=False, action='store_true', + help='Use pin_memory for faster CPU to GPU transfer.') + args = parser.parse_args() + + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) + + main(args) diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py new file mode 100644 index 0000000000..61df67f039 --- /dev/null +++ b/examples/contrib/mue/ProfileHMM.py @@ -0,0 +1,231 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +A standard profile HMM model [1], which corresponds to a constant (delta +function) distribution with a MuE observation [2]. This is a standard +generative model of variable-length biological sequences (e.g. proteins) which +does not require preprocessing the data by building a multiple sequence +alignment. It can be compared to a more complex MuE model in this package, +the FactorMuE. + +An example dataset consisting of proteins similar to the human papillomavirus E6 +protein, collected from a non-redundant sequence dataset using jackhmmer, can +be found at +https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta + +Example run: +python ProfileHMM.py -f PATH/ve6_full.fasta -b 10 -M 174 --indel-prior-bias 10. + -e 15 -lr 0.01 --jit --cuda +This should take about 9 minutes to run on a GPU. The perplexity should be +around 6. + +References: +[1] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) +"Biological sequence analysis: probabilistic models of proteins and nucleic +acids" +Cambridge university press + +[2] E. N. Weinstein, D. S. Marks (2021) +"Generative probabilistic biological sequence models that account for +mutational variability" +https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf +""" + +import argparse +import datetime +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim import Adam + +import pyro +from pyro.contrib.mue.dataloaders import BiosequenceDataset +from pyro.contrib.mue.models import ProfileHMM +from pyro.optim import MultiStepLR + + +def generate_data(small_test, include_stop, device): + """Generate mini example dataset.""" + if small_test: + mult_dat = 1 + else: + mult_dat = 10 + + seqs = ['BABBA']*mult_dat + ['BAAB']*mult_dat + ['BABBB']*mult_dat + dataset = BiosequenceDataset(seqs, 'list', 'AB', include_stop=include_stop, + device=device) + + return dataset + + +def main(args): + + pyro.set_rng_seed(args.rng_seed) + + # Load dataset. + if args.cpu_data and args.cuda: + device = torch.device('cpu') + else: + device = None + if args.test: + dataset = generate_data(args.small, args.include_stop, device) + else: + dataset = BiosequenceDataset(args.file, 'fasta', args.alphabet, + include_stop=args.include_stop, + device=device) + args.batch_size = min([dataset.data_size, args.batch_size]) + if args.split > 0.: + # Train test split. + heldout_num = int(np.ceil(args.split*len(dataset))) + data_lengths = [len(dataset) - heldout_num, heldout_num] + # Specific data split seed, for comparability across models and + # parameter initializations. 
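+        # As in FactorMuE.py, a fixed rng_data_seed makes the shuffled index
+        # permutation, and hence the train/test split, reproducible across runs.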
+ pyro.set_rng_seed(args.rng_data_seed) + indices = torch.randperm(sum(data_lengths)).tolist() + dataset_train, dataset_test = [ + torch.utils.data.Subset(dataset, indices[(offset - length):offset]) + for offset, length in zip(torch._utils._accumulate(data_lengths), + data_lengths)] + else: + dataset_train = dataset + dataset_test = None + + # Construct model. + latent_seq_length = args.latent_seq_length + if latent_seq_length is None: + latent_seq_length = int(dataset.max_length * 1.1) + model = ProfileHMM(latent_seq_length, dataset.alphabet_length, + prior_scale=args.prior_scale, + indel_prior_bias=args.indel_prior_bias, + cuda=args.cuda, + pin_memory=args.pin_mem) + + # Infer with SVI. + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': args.learning_rate}, + 'milestones': json.loads(args.milestones), + 'gamma': args.learning_gamma}) + n_epochs = args.n_epochs + losses = model.fit_svi(dataset, n_epochs, args.batch_size, scheduler, + args.jit) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset_train, dataset_test, args.jit) + print('train logp: {} perplex: {}'.format(train_lp, train_perplex)) + print('test logp: {} perplex: {}'.format(test_lp, test_perplex)) + + # Plots. + time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + if not args.no_plots: + plt.figure(figsize=(6, 6)) + plt.plot(losses) + plt.xlabel('step') + plt.ylabel('loss') + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.loss_{}.pdf'.format(time_stamp))) + + plt.figure(figsize=(6, 6)) + insert = pyro.param("insert_q_mn").detach() + insert_expect = torch.exp(insert - insert.logsumexp(-1, True)) + plt.plot(insert_expect[:, :, 1].cpu().numpy()) + plt.xlabel('position') + plt.ylabel('probability of insert') + plt.legend([r'$r_0$', r'$r_1$', r'$r_2$']) + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.insert_prob_{}.pdf'.format(time_stamp))) + plt.figure(figsize=(6, 6)) + delete = pyro.param("delete_q_mn").detach() + delete_expect = torch.exp(delete - delete.logsumexp(-1, True)) + plt.plot(delete_expect[:, :, 1].cpu().numpy()) + plt.xlabel('position') + plt.ylabel('probability of delete') + plt.legend([r'$u_0$', r'$u_1$', r'$u_2$']) + if not args.no_save: + plt.savefig(os.path.join( + args.out_folder, + 'ProfileHMM_plot.delete_prob_{}.pdf'.format(time_stamp))) + + if not args.no_save: + pyro.get_param_store().save(os.path.join( + args.out_folder, + 'ProfileHMM_results.params_{}.out'.format(time_stamp))) + with open(os.path.join( + args.out_folder, + 'ProfileHMM_results.evaluation_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('train_lp,test_lp,train_perplex,test_perplex\n') + ow.write('{},{},{},{}\n'.format(train_lp, test_lp, train_perplex, + test_perplex)) + with open(os.path.join( + args.out_folder, + 'ProfileHMM_results.input_{}.txt'.format(time_stamp)), + 'w') as ow: + ow.write('[args]\n') + for elem in list(args.__dict__.keys()): + ow.write('{} = {}\n'.format(elem, args.__getattribute__(elem))) + + +if __name__ == '__main__': + # Parse command line arguments. 
+ parser = argparse.ArgumentParser(description="Profile HMM model.") + parser.add_argument("--test", action='store_true', default=False, + help='Run with generated example dataset.') + parser.add_argument("--small", action='store_true', default=False, + help='Run with small example dataset.') + parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("--rng-data-seed", default=0, type=int) + parser.add_argument("-f", "--file", default=None, type=str, + help='Input file (fasta format).') + parser.add_argument("-a", "--alphabet", default='amino-acid', + help='Alphabet (amino-acid OR dna OR ATGC ...).') + parser.add_argument("-b", "--batch-size", default=10, type=int, + help='Batch size.') + parser.add_argument("-M", "--latent-seq-length", default=None, type=int, + help='Latent sequence length.') + parser.add_argument("--include-stop", default=False, action='store_true', + help='Include stop symbol at the end of each sequence.') + parser.add_argument("--prior-scale", default=1., type=float, + help='Prior scale parameter (all parameters).') + parser.add_argument("--indel-prior-bias", default=10., type=float, + help='Indel prior bias parameter.') + parser.add_argument("-lr", "--learning-rate", default=0.001, type=float, + help='Learning rate for Adam optimizer.') + parser.add_argument("--milestones", default='[]', type=str, + help='Milestones for multistage learning rate.') + parser.add_argument("--learning-gamma", default=0.5, type=float, + help='Gamma parameter for multistage learning rate.') + parser.add_argument("-e", "--n-epochs", default=10, type=int, + help='Number of epochs of training.') + parser.add_argument("--no-plots", default=False, action='store_true', + help='Make plots.') + parser.add_argument("--no-save", default=False, action='store_true', + help='Do not save plots and results.') + parser.add_argument("-outf", "--out-folder", default='.', + help='Folder to save plots.') + parser.add_argument("--split", default=0.2, type=float, + help=('Fraction of dataset to holdout for testing')) + parser.add_argument("--jit", default=False, action='store_true', + help='JIT compile the ELBO.') + parser.add_argument("--cuda", default=False, action='store_true', + help='Use GPU.') + parser.add_argument("--cpu-data", default=False, action='store_true', + help='Keep data on CPU (for large datasets).') + parser.add_argument("--pin-mem", default=False, action='store_true', + help='Use pin_memory for faster GPU transfer.') + args = parser.parse_args() + + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) + + main(args) diff --git a/pyro/contrib/mue/__init__.py b/pyro/contrib/mue/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py new file mode 100644 index 0000000000..b6fbbb4489 --- /dev/null +++ b/pyro/contrib/mue/dataloaders.py @@ -0,0 +1,112 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +from torch.utils.data import Dataset + +alphabets = {'amino-acid': np.array( + ['R', 'H', 'K', 'D', 'E', + 'S', 'T', 'N', 'Q', 'C', + 'G', 'P', 'A', 'V', 'I', + 'L', 'M', 'F', 'Y', 'W']), + 'dna': np.array(['A', 'C', 'G', 'T'])} + + +class BiosequenceDataset(Dataset): + """ + Load biological sequence data, either from a fasta file or a python list. + + :param source: Either the input fasta file path (str) or the input list + of sequences (list of str). 
+ :param str source_type: Type of input, either 'list' or 'fasta'. + :param str alphabet: Alphabet to use. Alphabets 'amino-acid' and 'dna' are + preset; any other input will be interpreted as the alphabet itself, + i.e. you can use 'ACGU' for RNA. + :param int max_length: Total length of the one-hot representation of the + sequences, including zero padding. Defaults to the maximum sequence + length in the dataset. + :param bool include_stop: Append stop symbol to the end of each sequence + and add the stop symbol to the alphabet. + :param torch.device device: Device on which data should be stored in + memory. + """ + + def __init__(self, source, source_type='list', alphabet='amino-acid', + max_length=None, include_stop=False, device=None): + + super().__init__() + + # Determine device + if device is None: + device = torch.tensor(0.).device + self.device = device + + # Get sequences. + self.include_stop = include_stop + if source_type == 'list': + seqs = [seq + include_stop*'*' for seq in source] + elif source_type == 'fasta': + seqs = self._load_fasta(source) + + # Get lengths. + self.L_data = torch.tensor([float(len(seq)) for seq in seqs], + device=device) + if max_length is None: + self.max_length = int(torch.max(self.L_data)) + else: + self.max_length = max_length + self.data_size = len(self.L_data) + + # Get alphabet. + if alphabet in alphabets: + alphabet = alphabets[alphabet] + else: + alphabet = np.array(list(alphabet)) + if self.include_stop: + alphabet = np.array(list(alphabet) + ['*']) + self.alphabet = alphabet + self.alphabet_length = len(alphabet) + + # Build dataset. + self.seq_data = torch.cat([self._one_hot( + seq, alphabet, self.max_length).unsqueeze(0) for seq in seqs]) + + def _load_fasta(self, source): + """A basic multiline fasta parser.""" + seqs = [] + seq = '' + with open(source, 'r') as fr: + for line in fr: + if line[0] == '>': + if seq != '': + if self.include_stop: + seq += '*' + seqs.append(seq) + seq = '' + else: + seq += line.strip('\n') + if seq != '': + if self.include_stop: + seq += '*' + seqs.append(seq) + return seqs + + def _one_hot(self, seq, alphabet, length): + """One hot encode and pad with zeros to max length.""" + # One hot encode. + oh = torch.tensor((np.array(list(seq))[:, None] == alphabet[None, :] + ).astype(np.float64), device=self.device) + # Pad. + x = torch.cat([oh, torch.zeros([length - len(seq), len(alphabet)], + device=self.device)]) + + return x + + def __len__(self): + + return self.data_size + + def __getitem__(self, ind): + + return (self.seq_data[ind], self.L_data[ind]) diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py new file mode 100644 index 0000000000..eb414bf82c --- /dev/null +++ b/pyro/contrib/mue/missingdatahmm.py @@ -0,0 +1,104 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from pyro.distributions import constraints +from pyro.distributions.hmm import _sequential_logmatmulexp +from pyro.distributions.torch_distribution import TorchDistribution +from pyro.distributions.util import broadcast_shape + + +class MissingDataDiscreteHMM(TorchDistribution): + """ + HMM with discrete latent states and discrete observations, allowing for + missing data or variable length sequences. Observations are assumed + to be one hot encoded; rows with all zeros indicate missing data. + + .. 
warning:: Unlike in pyro's pyro.distributions.DiscreteHMM, which + computes the probability of the first state as + initial.T @ transition @ emission + this distribution uses the standard HMM convention, + initial.T @ emission + + :param ~torch.Tensor initial_logits: A logits tensor for an initial + categorical distribution over latent states. Should have rightmost + size ``state_dim`` and be broadcastable to + ``(batch_size, state_dim)``. + :param ~torch.Tensor transition_logits: A logits tensor for transition + conditional distributions between latent states. Should have rightmost + shape ``(state_dim, state_dim)`` (old, new), and be broadcastable + to ``(batch_size, state_dim, state_dim)``. + :param ~torch.Tensor observation_logits: A logits tensor for observation + distributions from latent states. Should have rightmost shape + ``(state_dim, categorical_size)``, where ``categorical_size`` is the + dimension of the categorical output, and be broadcastable + to ``(batch_size, state_dim, categorical_size)``. + """ + arg_constraints = {"initial_logits": constraints.real_vector, + "transition_logits": constraints.independent( + constraints.real, 2), + "observation_logits": constraints.independent( + constraints.real, 2)} + support = constraints.independent(constraints.nonnegative_integer, 2) + + def __init__(self, initial_logits, transition_logits, observation_logits, + validate_args=None): + if initial_logits.dim() < 1: + raise ValueError( + "expected initial_logits to have at least one dim, " + "actual shape = {}".format(initial_logits.shape)) + if transition_logits.dim() < 2: + raise ValueError( + "expected transition_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) + if observation_logits.dim() < 2: + raise ValueError( + "expected observation_logits to have at least two dims, " + "actual shape = {}".format(transition_logits.shape)) + shape = broadcast_shape(initial_logits.shape[:-1], + transition_logits.shape[:-2], + observation_logits.shape[:-2]) + if len(shape) == 0: + shape = torch.Size([1]) + batch_shape = shape + event_shape = (1, observation_logits.shape[-1]) + self.initial_logits = (initial_logits - + initial_logits.logsumexp(-1, True)) + self.transition_logits = (transition_logits - + transition_logits.logsumexp(-1, True)) + self.observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + super(MissingDataDiscreteHMM, self).__init__( + batch_shape, event_shape, validate_args=validate_args) + + def log_prob(self, value): + """ + :param ~torch.Tensor value: One-hot encoded observation. Must be + real-valued (float) and broadcastable to + ``(batch_size, num_steps, categorical_size)`` where + ``categorical_size`` is the dimension of the categorical output. + Missing data is represented by zeros, i.e. + ``value[batch, step, :] == tensor([0, ..., 0])``. + Variable length observation sequences can be handled by padding + the sequence with zeros at the end. + """ + + assert value.shape[-1] == self.event_shape[1] + + # Combine observation and transition factors. + value_logits = torch.matmul( + value, torch.transpose(self.observation_logits, -2, -1)) + result = (self.transition_logits.unsqueeze(-3) + + value_logits[..., 1:, None, :]) + + # Eliminate time dimension. + result = _sequential_logmatmulexp(result) + + # Combine initial factor. + result = (self.initial_logits + value_logits[..., 0, :] + + result.logsumexp(-1)) + + # Marginalize out final state. 
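+        # `result` now has a single remaining hidden-state dimension; the final
+        # logsumexp sums it out, leaving log p(value) with the broadcast batch shape.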
+ result = result.logsumexp(-1) + return result diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py new file mode 100644 index 0000000000..fb55a2fa9f --- /dev/null +++ b/pyro/contrib/mue/models.py @@ -0,0 +1,731 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example MuE observation models. +""" + +import datetime + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.functional import softplus +from torch.optim import Adam +from torch.utils.data import DataLoader + +import pyro +import pyro.distributions as dist +from pyro import poutine +from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM +from pyro.contrib.mue.statearrangers import Profile +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO +from pyro.optim import MultiStepLR + + +class ProfileHMM(nn.Module): + """ + Profile HMM. + + This model consists of a constant distribution (a delta function) over the + regressor sequence, plus a MuE observation distribution. The priors + are all Normal distributions, and are pushed through a softmax function + onto the simplex. + + :param int latent_seq_length: Length of the latent regressor sequence M. + Must be greater than or equal to 1. + :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for + amino acids). + :param float prior_scale: Standard deviation of the prior distribution. + :param float indel_prior_bias: Mean of the prior distribution over the + log probability of an indel not occurring. Higher values lead to lower + probability of indels. + :param bool cuda: Transfer data onto the GPU during training. + :param bool pin_memory: Pin memory for faster GPU transfer. + """ + def __init__(self, latent_seq_length, alphabet_length, + prior_scale=1., indel_prior_bias=10., + cuda=False, pin_memory=False): + super().__init__() + assert isinstance(cuda, bool) + self.is_cuda = cuda + assert isinstance(pin_memory, bool) + self.pin_memory = pin_memory + + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + + self.precursor_seq_shape = (latent_seq_length, alphabet_length) + self.insert_seq_shape = (latent_seq_length+1, alphabet_length) + self.indel_shape = (latent_seq_length, 3, 2) + + assert isinstance(prior_scale, float) + self.prior_scale = prior_scale + assert isinstance(indel_prior_bias, float) + self.indel_prior = torch.tensor([indel_prior_bias, 0.]) + + # Initialize state arranger. + self.statearrange = Profile(latent_seq_length) + + def model(self, seq_data, local_scale): + + # Latent sequence. + precursor_seq = pyro.sample("precursor_seq", dist.Normal( + torch.zeros(self.precursor_seq_shape), + self.prior_scale * + torch.ones(self.precursor_seq_shape)).to_event(2)) + precursor_seq_logits = precursor_seq - precursor_seq.logsumexp(-1, True) + insert_seq = pyro.sample("insert_seq", dist.Normal( + torch.zeros(self.insert_seq_shape), + self.prior_scale * + torch.ones(self.insert_seq_shape)).to_event(2)) + insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True) + + # Indel probabilities. 
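+        # insert/delete are unnormalized logits of shape (M, 3, 2); the prior
+        # mean [indel_prior_bias, 0.] favors the "no indel" outcome, and the
+        # logsumexp below renormalizes the last axis before the statearranger
+        # assembles the HMM transition parameters.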
+ insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.prior_scale * torch.ones(self.indel_shape)).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Construct HMM parameters. + initial_logits, transition_logits, observation_logits = ( + self.statearrange(precursor_seq_logits, insert_seq_logits, + insert_logits, delete_logits)) + + with pyro.plate("batch", seq_data.shape[0]): + with poutine.scale(scale=local_scale): + # Observations. + pyro.sample("obs_seq", + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=seq_data) + + def guide(self, seq_data, local_scale): + # Sequence. + precursor_seq_q_mn = pyro.param("precursor_seq_q_mn", + torch.zeros(self.precursor_seq_shape)) + precursor_seq_q_sd = pyro.param("precursor_seq_q_sd", + torch.zeros(self.precursor_seq_shape)) + pyro.sample("precursor_seq", dist.Normal( + precursor_seq_q_mn, softplus(precursor_seq_q_sd)).to_event(2)) + insert_seq_q_mn = pyro.param("insert_seq_q_mn", + torch.zeros(self.insert_seq_shape)) + insert_seq_q_sd = pyro.param("insert_seq_q_sd", + torch.zeros(self.insert_seq_shape)) + pyro.sample("insert_seq", dist.Normal( + insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2)) + + # Indels. + insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + + def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, + jit=False): + """ + Infer approximate posterior with stochastic variational inference. + + This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference + method useful for quickly iterating on probabilistic models. + + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param int epochs: Number of epochs of training. + :param int batch_size: Minibatch size (number of sequences). + :param pyro.optim.MultiStepLR scheduler: Optimization scheduler. + (Default: Adam optimizer, 0.01 constant learning rate.) + :param bool jit: Whether to use a jit compiled ELBO. + """ + + # Setup. + if batch_size is not None: + self.batch_size = batch_size + if scheduler is None: + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.01}, + 'milestones': [], + 'gamma': 0.5}) + # Initialize guide. + self.guide(None, None) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, + pin_memory=self.pin_memory) + # Setup stochastic variational inference. + if jit: + elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=elbo) + + # Run inference. 
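+        # Each svi.step scales the data log-likelihood by
+        # len(dataset) / batch size (the local_scale argument), so the minibatch
+        # ELBO is an unbiased estimate of the full-dataset ELBO.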
+ losses = [] + t0 = datetime.datetime.now() + for epoch in range(epochs): + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data = seq_data.cuda() + loss = svi.step(seq_data, + torch.tensor(len(dataset)/seq_data.shape[0])) + losses.append(loss) + scheduler.step() + print(epoch, loss, ' ', datetime.datetime.now() - t0) + return losses + + def evaluate(self, dataset_train, dataset_test=None, jit=False): + """ + Evaluate performance (log probability and per residue perplexity) on + train and test datasets. + + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param ~torch.utils.data.Dataset dataset: The testing dataset. + :param bool jit: Whether to use a jit compiled ELBO. + """ + dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=1, + shuffle=False) + # Initialize guide. + self.guide(None, None) + if jit: + elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + elbo = Trace_ELBO() + scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + # Setup stochastic variational inference. + svi = SVI(self.model, self.guide, scheduler, loss=elbo) + + # Compute elbo and perplexity. + train_lp, train_perplex = self._evaluate_local_elbo( + svi, dataload_train, len(dataset_train)) + if dataset_test is not None: + test_lp, test_perplex = self._evaluate_local_elbo( + svi, dataload_test, len(dataset_test)) + return train_lp, test_lp, train_perplex, test_perplex + else: + return train_lp, None, train_perplex, None + + def _local_variables(self, name, site): + """Return per datapoint random variables in model.""" + return name in ['obs_L', 'obs_seq'] + + def _evaluate_local_elbo(self, svi, dataload, data_size): + """Evaluate elbo and average per residue perplexity.""" + lp, perplex = 0., 0. + with torch.no_grad(): + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() + conditioned_model = poutine.condition(self.model, data={ + "obs_seq": seq_data}) + args = (seq_data, torch.tensor(1.)) + guide_tr = poutine.trace(self.guide).get_trace(*args) + model_tr = poutine.trace(poutine.replay( + conditioned_model, trace=guide_tr)).get_trace(*args) + local_elbo = (model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ).cpu().numpy() + lp += local_elbo + perplex += -local_elbo / L_data[0].cpu().numpy() + perplex = np.exp(perplex / data_size) + return lp, perplex + + +class Encoder(nn.Module): + def __init__(self, data_length, alphabet_length, z_dim): + super().__init__() + + self.input_size = data_length * alphabet_length + self.f1_mn = nn.Linear(self.input_size, z_dim) + self.f1_sd = nn.Linear(self.input_size, z_dim) + + def forward(self, data): + + data = data.reshape(-1, self.input_size) + z_loc = self.f1_mn(data) + z_scale = softplus(self.f1_sd(data)) + + return z_loc, z_scale + + +class FactorMuE(nn.Module): + """ + FactorMuE + + This model consists of probabilistic PCA plus a MuE output distribution. + + The priors are all Normal distributions, and where relevant pushed through + a softmax onto the simplex. + + :param int data_length: Length of the input sequence matrix, including + zero padding at the end. + :param int alphabet_length: Length of the sequence alphabet (e.g. 20 for + amino acids). + :param int z_dim: Number of dimensions of the z space. + :param int batch_size: Minibatch size. + :param int latent_seq_length: Length of the latent regressor sequence (M). 
+ Must be greater than or equal to 1. (Default: 1.1 x data_length.) + :param bool indel_factor_dependence: Indel probabilities depend on the + latent variable z. + :param float indel_prior_scale: Standard deviation of the prior + distribution on indel parameters. + :param float indel_prior_bias: Mean of the prior distribution over the + log probability of an indel not occurring. Higher values lead to lower + probability of indels. + :param float inverse_temp_prior: Mean of the prior distribution over the + inverse temperature parameter. + :param float weights_prior_scale: Standard deviation of the prior + distribution over the factors. + :param float offset_prior_scale: Standard deviation of the prior + distribution over the offset (constant) in the pPCA model. + :param str z_prior_distribution: Prior distribution over the latent + variable z. Either 'Normal' (pPCA model) or 'Laplace' (an ICA model). + :param bool ARD_prior: Use automatic relevance determination prior on + factors. + :param bool substitution_matrix: Use a learnable substitution matrix + rather than the identity matrix. + :param float substitution_prior_scale: Standard deviation of the prior + distribution over substitution matrix parameters (when + substitution_matrix is True). + :param int latent_alphabet_length: Length of the alphabet in the latent + regressor sequence. + :param bool cuda: Transfer data onto the GPU during training. + :param bool pin_memory: Pin memory for faster GPU transfer. + :param float epsilon: A small value for numerical stability. + """ + def __init__(self, data_length, alphabet_length, z_dim, + batch_size=10, + latent_seq_length=None, + indel_factor_dependence=False, + indel_prior_scale=1., + indel_prior_bias=10., + inverse_temp_prior=100., + weights_prior_scale=1., + offset_prior_scale=1., + z_prior_distribution='Normal', + ARD_prior=False, + substitution_matrix=True, + substitution_prior_scale=10., + latent_alphabet_length=None, + cuda=False, + pin_memory=False, + epsilon=1e-32): + super().__init__() + assert isinstance(cuda, bool) + self.is_cuda = cuda + assert isinstance(pin_memory, bool) + self.pin_memory = pin_memory + + # Constants. + assert isinstance(data_length, int) and data_length > 0 + self.data_length = data_length + if latent_seq_length is None: + latent_seq_length = int(data_length * 1.1) + else: + assert isinstance(latent_seq_length, int) and latent_seq_length > 0 + self.latent_seq_length = latent_seq_length + assert isinstance(alphabet_length, int) and alphabet_length > 0 + self.alphabet_length = alphabet_length + assert isinstance(z_dim, int) and z_dim > 0 + self.z_dim = z_dim + + # Parameter shapes. + if (not substitution_matrix) or (latent_alphabet_length is None): + latent_alphabet_length = alphabet_length + self.latent_alphabet_length = latent_alphabet_length + self.indel_shape = (latent_seq_length, 3, 2) + self.total_factor_size = ( + (2*latent_seq_length+1)*latent_alphabet_length + + 2*indel_factor_dependence*latent_seq_length*3*2) + + # Architecture. + self.indel_factor_dependence = indel_factor_dependence + self.ARD_prior = ARD_prior + self.substitution_matrix = substitution_matrix + + # Priors. 
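+        # Prior hyperparameters are stored as tensors on the module; the
+        # two-component indel prior mean [indel_prior_bias, 0.] means a larger
+        # bias puts more prior mass on "no indel".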
+ assert isinstance(indel_prior_scale, float) + self.indel_prior_scale = torch.tensor(indel_prior_scale) + assert isinstance(indel_prior_bias, float) + self.indel_prior = torch.tensor([indel_prior_bias, 0.]) + assert isinstance(inverse_temp_prior, float) + self.inverse_temp_prior = torch.tensor(inverse_temp_prior) + assert isinstance(weights_prior_scale, float) + self.weights_prior_scale = torch.tensor(weights_prior_scale) + assert isinstance(offset_prior_scale, float) + self.offset_prior_scale = torch.tensor(offset_prior_scale) + assert isinstance(epsilon, float) + self.epsilon = torch.tensor(epsilon) + assert isinstance(substitution_prior_scale, float) + self.substitution_prior_scale = torch.tensor(substitution_prior_scale) + self.z_prior_distribution = z_prior_distribution + + # Batch control. + assert isinstance(batch_size, int) + self.batch_size = batch_size + + # Initialize layers. + self.encoder = Encoder(data_length, alphabet_length, z_dim) + self.statearrange = Profile(latent_seq_length) + + def decoder(self, z, W, B, inverse_temp): + + # Project. + v = torch.mm(z, W) + B + + out = dict() + if self.indel_factor_dependence: + # Extract insertion and deletion parameters. + ind0 = (2*self.latent_seq_length+1)*self.latent_alphabet_length + ind1 = ind0 + self.latent_seq_length*3*2 + ind2 = ind1 + self.latent_seq_length*3*2 + insert_v, delete_v = v[:, ind0:ind1], v[:, ind1:ind2] + insert_v = (insert_v.reshape([-1, self.latent_seq_length, 3, 2]) + + self.indel_prior) + out['insert_logits'] = insert_v - insert_v.logsumexp(-1, True) + delete_v = (delete_v.reshape([-1, self.latent_seq_length, 3, 2]) + + self.indel_prior) + out['delete_logits'] = delete_v - delete_v.logsumexp(-1, True) + # Extract precursor and insertion sequences. + ind0 = self.latent_seq_length*self.latent_alphabet_length + ind1 = ind0 + (self.latent_seq_length+1)*self.latent_alphabet_length + precursor_seq_v, insert_seq_v = v[:, :ind0], v[:, ind0:ind1] + precursor_seq_v = (precursor_seq_v*softplus(inverse_temp)).reshape([ + -1, self.latent_seq_length, self.latent_alphabet_length]) + out['precursor_seq_logits'] = ( + precursor_seq_v - precursor_seq_v.logsumexp(-1, True)) + insert_seq_v = (insert_seq_v*softplus(inverse_temp)).reshape([ + -1, self.latent_seq_length+1, self.latent_alphabet_length]) + out['insert_seq_logits'] = ( + insert_seq_v - insert_seq_v.logsumexp(-1, True)) + + return out + + def model(self, seq_data, local_scale, local_prior_scale): + + # ARD prior. + if self.ARD_prior: + # Relevance factors + alpha = pyro.sample("alpha", dist.Gamma( + torch.ones(self.z_dim), torch.ones(self.z_dim)).to_event(1)) + else: + alpha = torch.ones(self.z_dim) + + # Factor and offset. + W = pyro.sample("W", dist.Normal( + torch.zeros([self.z_dim, self.total_factor_size]), + torch.ones([self.z_dim, self.total_factor_size]) * + self.weights_prior_scale / (alpha[:, None] + self.epsilon) + ).to_event(2)) + B = pyro.sample("B", dist.Normal( + torch.zeros(self.total_factor_size), + torch.ones(self.total_factor_size) * self.offset_prior_scale + ).to_event(1)) + + # Indel probabilities. 
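+        # When indel_factor_dependence=True these logits are instead produced
+        # by the decoder (and so depend on z); otherwise they are global latent
+        # variables sampled here.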
+ if not self.indel_factor_dependence: + insert = pyro.sample("insert", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape) + ).to_event(3)) + insert_logits = insert - insert.logsumexp(-1, True) + delete = pyro.sample("delete", dist.Normal( + self.indel_prior * torch.ones(self.indel_shape), + self.indel_prior_scale * torch.ones(self.indel_shape) + ).to_event(3)) + delete_logits = delete - delete.logsumexp(-1, True) + + # Inverse temperature. + inverse_temp = pyro.sample("inverse_temp", dist.Normal( + self.inverse_temp_prior, torch.tensor(1.))) + + # Substitution matrix. + if self.substitution_matrix: + substitute = pyro.sample("substitute", dist.Normal( + torch.zeros([ + self.latent_alphabet_length, self.alphabet_length]), + self.substitution_prior_scale * torch.ones([ + self.latent_alphabet_length, self.alphabet_length]) + ).to_event(2)) + + with pyro.plate("batch", seq_data.shape[0]): + with poutine.scale(scale=local_scale): + with poutine.scale(scale=local_prior_scale): + # Sample latent variable from prior. + if self.z_prior_distribution == 'Normal': + z = pyro.sample("latent", dist.Normal( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + z = pyro.sample("latent", dist.Laplace( + torch.zeros(self.z_dim), torch.ones(self.z_dim) + ).to_event(1)) + + # Decode latent sequence. + decoded = self.decoder(z, W, B, inverse_temp) + if self.indel_factor_dependence: + insert_logits = decoded['insert_logits'] + delete_logits = decoded['delete_logits'] + + # Construct HMM parameters. + if self.substitution_matrix: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits, + substitute)) + else: + initial_logits, transition_logits, observation_logits = ( + self.statearrange(decoded['precursor_seq_logits'], + decoded['insert_seq_logits'], + insert_logits, delete_logits)) + # Draw samples. + pyro.sample("obs_seq", + MissingDataDiscreteHMM(initial_logits, + transition_logits, + observation_logits), + obs=seq_data) + + def guide(self, seq_data, local_scale, local_prior_scale): + # Register encoder with pyro. + pyro.module("encoder", self.encoder) + + # ARD weightings. + if self.ARD_prior: + alpha_conc = pyro.param("alpha_conc", torch.randn(self.z_dim)) + alpha_rate = pyro.param("alpha_rate", torch.randn(self.z_dim)) + pyro.sample("alpha", dist.Gamma(softplus(alpha_conc), + softplus(alpha_rate)).to_event(1)) + # Factors. + W_q_mn = pyro.param("W_q_mn", torch.randn([ + self.z_dim, self.total_factor_size])) + W_q_sd = pyro.param("W_q_sd", torch.ones([ + self.z_dim, self.total_factor_size])) + pyro.sample("W", dist.Normal(W_q_mn, softplus(W_q_sd)).to_event(2)) + B_q_mn = pyro.param("B_q_mn", torch.randn(self.total_factor_size)) + B_q_sd = pyro.param("B_q_sd", torch.ones(self.total_factor_size)) + pyro.sample("B", dist.Normal(B_q_mn, softplus(B_q_sd)).to_event(1)) + + # Indel probabilities. 
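+        # The guide mirrors the model: each global indel variable gets a
+        # mean-field Normal posterior whose scale is kept positive via softplus.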
+ if not self.indel_factor_dependence: + insert_q_mn = pyro.param("insert_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + insert_q_sd = pyro.param("insert_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("insert", dist.Normal( + insert_q_mn, softplus(insert_q_sd)).to_event(3)) + delete_q_mn = pyro.param("delete_q_mn", + torch.ones(self.indel_shape) + * self.indel_prior) + delete_q_sd = pyro.param("delete_q_sd", + torch.zeros(self.indel_shape)) + pyro.sample("delete", dist.Normal( + delete_q_mn, softplus(delete_q_sd)).to_event(3)) + + # Inverse temperature. + inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.)) + inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.)) + pyro.sample("inverse_temp", dist.Normal( + inverse_temp_q_mn, softplus(inverse_temp_q_sd))) + + # Substitution matrix. + if self.substitution_matrix: + substitute_q_mn = pyro.param("substitute_q_mn", torch.zeros( + [self.latent_alphabet_length, self.alphabet_length])) + substitute_q_sd = pyro.param("substitute_q_sd", torch.zeros( + [self.latent_alphabet_length, self.alphabet_length])) + pyro.sample("substitute", dist.Normal( + substitute_q_mn, softplus(substitute_q_sd)).to_event(2)) + + # Per datapoint local latent variables. + with pyro.plate("batch", seq_data.shape[0]): + # Encode sequences. + z_loc, z_scale = self.encoder(seq_data) + # Scale log likelihood to account for mini-batching. + with poutine.scale(scale=local_scale*local_prior_scale): + # Sample. + if self.z_prior_distribution == 'Normal': + pyro.sample("latent", + dist.Normal(z_loc, z_scale).to_event(1)) + elif self.z_prior_distribution == 'Laplace': + pyro.sample("latent", + dist.Laplace(z_loc, z_scale).to_event(1)) + + def fit_svi(self, dataset, epochs=2, anneal_length=1., batch_size=None, + scheduler=None, jit=False): + """ + Infer approximate posterior with stochastic variational inference. + + This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference + method useful for quickly iterating on probabilistic models. + + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param int epochs: Number of epochs of training. + :param float anneal_length: Number of epochs over which to linearly + anneal the prior KL divergence weight from 0 to 1, for improved + training. + :param int batch_size: Minibatch size (number of sequences). + :param pyro.optim.MultiStepLR scheduler: Optimization scheduler. + (Default: Adam optimizer, 0.01 constant learning rate.) + :param bool jit: Whether to use a jit compiled ELBO. + """ + + # Setup. + if batch_size is not None: + self.batch_size = batch_size + if scheduler is None: + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.01}, + 'milestones': [], + 'gamma': 0.5}) + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=True, + pin_memory=self.pin_memory) + # Initialize guide. + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data = seq_data.cuda() + self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) + break + # Setup stochastic variational inference. + if jit: + elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + elbo = Trace_ELBO() + svi = SVI(self.model, self.guide, scheduler, loss=elbo) + + # Run inference. 
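+        # The third argument to svi.step is the prior KL weight from
+        # _beta_anneal, which ramps linearly from ~0 to 1 over anneal_length
+        # epochs (beta annealing, per the fit_svi docstring, for improved training).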
+ losses = [] + step_i = 1 + t0 = datetime.datetime.now() + for epoch in range(epochs): + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data = seq_data.cuda() + loss = svi.step( + seq_data, torch.tensor(len(dataset)/seq_data.shape[0]), + self._beta_anneal(step_i, batch_size, len(dataset), + anneal_length)) + losses.append(loss) + scheduler.step() + step_i += 1 + print(epoch, loss, ' ', datetime.datetime.now() - t0) + + return losses + + def _beta_anneal(self, step, batch_size, data_size, anneal_length): + """Annealing schedule for prior KL term (beta annealing).""" + if np.allclose(anneal_length, 0.): + return torch.tensor(1.) + anneal_frac = step*batch_size/(anneal_length*data_size) + return torch.tensor(min([anneal_frac, 1.])) + + def evaluate(self, dataset_train, dataset_test=None, jit=False): + """ + Evaluate performance (log probability and per residue perplexity) on + train and test datasets. + + :param ~torch.utils.data.Dataset dataset: The training dataset. + :param ~torch.utils.data.Dataset dataset: The testing dataset + (optional). + :param bool jit: Whether to use a jit compiled ELBO. + """ + dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False) + if dataset_test is not None: + dataload_test = DataLoader(dataset_test, batch_size=1, + shuffle=False) + # Initialize guide. + for seq_data, L_data in dataload_train: + if self.is_cuda: + seq_data = seq_data.cuda() + self.guide(seq_data, torch.tensor(1.), torch.tensor(1.)) + break + if jit: + elbo = JitTrace_ELBO(ignore_jit_warnings=True) + else: + elbo = Trace_ELBO() + scheduler = MultiStepLR({'optimizer': Adam, 'optim_args': {'lr': 0.01}}) + # Setup stochastic variational inference. + svi = SVI(self.model, self.guide, scheduler, loss=elbo) + + # Compute elbo and perplexity. + train_lp, train_perplex = self._evaluate_local_elbo( + svi, dataload_train, len(dataset_train)) + if dataset_test is not None: + test_lp, test_perplex = self._evaluate_local_elbo( + svi, dataload_test, len(dataset_test)) + return train_lp, test_lp, train_perplex, test_perplex + else: + return train_lp, None, train_perplex, None + + def _local_variables(self, name, site): + """Return per datapoint random variables in model.""" + return name in ['latent', 'obs_L', 'obs_seq'] + + def _evaluate_local_elbo(self, svi, dataload, data_size): + """Evaluate elbo and average per residue perplexity.""" + lp, perplex = 0., 0. + with torch.no_grad(): + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data, L_data = seq_data.cuda(), L_data.cuda() + conditioned_model = poutine.condition(self.model, data={ + "obs_seq": seq_data}) + args = (seq_data, torch.tensor(1.), torch.tensor(1.)) + guide_tr = poutine.trace(self.guide).get_trace(*args) + model_tr = poutine.trace(poutine.replay( + conditioned_model, trace=guide_tr)).get_trace(*args) + local_elbo = (model_tr.log_prob_sum(self._local_variables) + - guide_tr.log_prob_sum(self._local_variables) + ).cpu().numpy() + lp += local_elbo + perplex += -local_elbo / L_data[0].cpu().numpy() + perplex = np.exp(perplex / data_size) + return lp, perplex + + def embed(self, dataset, batch_size=None): + """ + Get the latent space embedding (mean posterior value of z). + + :param ~torch.utils.data.Dataset dataset: The dataset to embed. + :param int batch_size: Minibatch size (number of sequences). (Defaults + to batch_size of the model object.) 
+ """ + if batch_size is None: + batch_size = self.batch_size + dataload = DataLoader(dataset, batch_size=batch_size, shuffle=False) + with torch.no_grad(): + z_locs, z_scales = [], [] + for seq_data, L_data in dataload: + if self.is_cuda: + seq_data = seq_data.cuda() + z_loc, z_scale = self.encoder(seq_data) + z_locs.append(z_loc.cpu()) + z_scales.append(z_scale.cpu()) + + return torch.cat(z_locs), torch.cat(z_scales) + + def _reconstruct_regressor_seq(self, data, ind, param): + "Reconstruct the latent regressor sequence given data." + with torch.no_grad(): + # Encode seq. + z_loc = self.encoder(data[ind][0])[0] + # Reconstruct + decoded = self.decoder(z_loc, param("W_q_mn"), param("B_q_mn"), + param("inverse_temp_q_mn")) + return torch.exp(decoded['precursor_seq_logits']) diff --git a/pyro/contrib/mue/statearrangers.py b/pyro/contrib/mue/statearrangers.py new file mode 100644 index 0000000000..2c384ec720 --- /dev/null +++ b/pyro/contrib/mue/statearrangers.py @@ -0,0 +1,209 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn + + +class Profile(nn.Module): + """ + Profile HMM state arrangement. Parameterizes an HMM according to + Equation S40 in [1] (with r_{M+1,j} = 1 and u_{M+1,j} = 0 + for j in {0, 1, 2}). For further background on profile HMMs see [2]. + + **References** + + [1] E. N. Weinstein, D. S. Marks (2021) + "Generative probabilistic biological sequence models that account for + mutational variability" + https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf + + [2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) + "Biological sequence analysis: probabilistic models of proteins and nucleic + acids" + Cambridge university press + + :param M: Length of regressor sequence. + :type M: int + :param epsilon: A small value for numerical stability. + :type epsilon: float + """ + def __init__(self, M, epsilon=1e-32): + super().__init__() + self.M = M + self.K = 2*M+1 + self.epsilon = epsilon + + self._make_transfer() + + def _make_transfer(self): + """Set up linear transformations (transfer matrices) for converting + from profile HMM parameters to standard HMM parameters.""" + M, K = self.M, self.K + + # Overview: + # r -> insertion parameters + # u -> deletion parameters + # indices: m in {0, ..., M} and j in {0, 1, 2}; final index corresponds + # to simplex dimensions, i.e. 
1 - r and r (in that order) + # null -> locations in the transition matrix equal to 0 + # ...transf_0 -> initial transition vector + # ...transf -> transition matrix + # We fix r_{M+1,j} = 1 for j in {0, 1, 2} + self.register_buffer('r_transf_0', + torch.zeros((M, 3, 2, K))) + self.register_buffer('u_transf_0', + torch.zeros((M, 3, 2, K))) + self.register_buffer('null_transf_0', + torch.zeros((K,))) + m, g = -1, 0 + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) + if m + 1 - g == mp and gp == 0: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 0, kp] = 1 + + elif m + 1 - g < mp and gp == 0: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 1, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf_0[mpp, 2, 0, kp] = 1 + self.u_transf_0[mpp, 2, 1, kp] = 1 + self.r_transf_0[mp, 2, 0, kp] = 1 + self.u_transf_0[mp, 2, 0, kp] = 1 + + elif m + 1 - g == mp and gp == 1: + if mp < M: + self.r_transf_0[m+1-g, g, 1, kp] = 1 + + elif m + 1 - g < mp and gp == 1: + self.r_transf_0[m+1-g, g, 0, kp] = 1 + self.u_transf_0[m+1-g, g, 1, kp] = 1 + for mpp in range(m+2-g, mp): + self.r_transf_0[mpp, 2, 0, kp] = 1 + self.u_transf_0[mpp, 2, 1, kp] = 1 + if mp < M: + self.r_transf_0[mp, 2, 1, kp] = 1 + + else: + self.null_transf_0[kp] = 1 + + self.register_buffer('r_transf', + torch.zeros((M, 3, 2, K, K))) + self.register_buffer('u_transf', + torch.zeros((M, 3, 2, K, K))) + self.register_buffer('null_transf', + torch.zeros((K, K))) + for g in range(2): + for m in range(M+g): + for gp in range(2): + for mp in range(M+gp): + k, kp = mg2k(m, g, M), mg2k(mp, gp, M) + if m + 1 - g == mp and gp == 0: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 0, k, kp] = 1 + + elif m + 1 - g < mp and gp == 0: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 1, k, kp] = 1 + self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 + self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 + self.r_transf[mp, 2, 0, k, kp] = 1 + self.u_transf[mp, 2, 0, k, kp] = 1 + + elif m + 1 - g == mp and gp == 1: + if mp < M: + self.r_transf[m+1-g, g, 1, k, kp] = 1 + + elif m + 1 - g < mp and gp == 1: + self.r_transf[m+1-g, g, 0, k, kp] = 1 + self.u_transf[m+1-g, g, 1, k, kp] = 1 + self.r_transf[(m+2-g):mp, 2, 0, k, kp] = 1 + self.u_transf[(m+2-g):mp, 2, 1, k, kp] = 1 + if mp < M: + self.r_transf[mp, 2, 1, k, kp] = 1 + + else: + self.null_transf[k, kp] = 1 + + self.register_buffer('vx_transf', + torch.zeros((M, K))) + self.register_buffer('vc_transf', + torch.zeros((M+1, K))) + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) + if g == 0: + self.vx_transf[m, k] = 1 + elif g == 1: + self.vc_transf[m, k] = 1 + + def forward(self, precursor_seq_logits, insert_seq_logits, + insert_logits, delete_logits, substitute_logits=None): + """ + Assemble HMM parameters given profile parameters. + + :param ~torch.Tensor precursor_seq_logits: Regressor sequence + *log(x)*. Should have rightmost dimension ``(M, D)`` and be + broadcastable to ``(batch_size, M, D)``, where + D is the latent alphabet size. Should be normalized to one along the + final axis, i.e. ``precursor_seq_logits.logsumexp(-1) = zeros``. + :param ~torch.Tensor insert_seq_logits: Insertion sequence *log(c)*. + Should have rightmost dimension ``(M+1, D)`` and be broadcastable + to ``(batch_size, M+1, D)``. Should be normalized + along the final axis. + :param ~torch.Tensor insert_logits: Insertion probabilities *log(r)*. + Should have rightmost dimension ``(M, 3, 2)`` and be broadcastable + to ``(batch_size, M, 3, 2)``. 
Should be normalized along the + final axis. + :param ~torch.Tensor delete_logits: Deletion probabilities *log(u)*. + Should have rightmost dimension ``(M, 3, 2)`` and be broadcastable + to ``(batch_size, M, 3, 2)``. Should be normalized along the + final axis. + :param ~torch.Tensor substitute_logits: Substitution probabilities + *log(l)*. Should have rightmost dimension ``(D, B)``, where + B is the alphabet size of the data, and broadcastable to + ``(batch_size, D, B)``. Must be normalized along the + final axis. + :return: *initial_logits*, *transition_logits*, and + *observation_logits*. These parameters can be used to directly + initialize the MissingDataDiscreteHMM distribution. + :rtype: ~torch.Tensor, ~torch.Tensor, ~torch.Tensor + """ + initial_logits = ( + torch.einsum('...ijk,ijkl->...l', + delete_logits, self.u_transf_0) + + torch.einsum('...ijk,ijkl->...l', + insert_logits, self.r_transf_0) + + (-1/self.epsilon)*self.null_transf_0) + transition_logits = ( + torch.einsum('...ijk,ijklf->...lf', + delete_logits, self.u_transf) + + torch.einsum('...ijk,ijklf->...lf', + insert_logits, self.r_transf) + + (-1/self.epsilon)*self.null_transf) + # Broadcasting for concatenation. + if len(precursor_seq_logits.size()) > len(insert_seq_logits.size()): + insert_seq_logits = insert_seq_logits.unsqueeze(0).expand( + [precursor_seq_logits.size()[0], -1, -1]) + elif len(insert_seq_logits.size()) > len(precursor_seq_logits.size()): + precursor_seq_logits = precursor_seq_logits.unsqueeze(0).expand( + [insert_seq_logits.size()[0], -1, -1]) + seq_logits = torch.cat([precursor_seq_logits, insert_seq_logits], + dim=-2) + + # Option to include the substitution matrix. + if substitute_logits is not None: + observation_logits = torch.logsumexp( + seq_logits.unsqueeze(-1) + substitute_logits.unsqueeze(-3), + dim=-2) + else: + observation_logits = seq_logits + + return initial_logits, transition_logits, observation_logits + + +def mg2k(m, g, M): + """Convert from (m, g) indexing to k indexing.""" + return m + M*g diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py new file mode 100644 index 0000000000..94ba2fa02a --- /dev/null +++ b/tests/contrib/mue/test_dataloaders.py @@ -0,0 +1,69 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.contrib.mue.dataloaders import BiosequenceDataset, alphabets + + +@pytest.mark.parametrize('source_type', ['list', 'fasta']) +@pytest.mark.parametrize('alphabet', ['amino-acid', 'dna', 'ATC']) +@pytest.mark.parametrize('include_stop', [False, True]) +def test_biosequencedataset(source_type, alphabet, include_stop): + + # Define dataset. + seqs = ['AATC', 'CA', 'T'] + + # Encode dataset, alternate approach. + if alphabet in alphabets: + alphabet_list = list(alphabets[alphabet]) + include_stop*['*'] + else: + alphabet_list = list(alphabet) + include_stop*['*'] + L_data_check = [len(seq) + include_stop for seq in seqs] + max_length_check = max(L_data_check) + data_size_check = len(seqs) + seq_data_check = torch.zeros([len(seqs), max_length_check, + len(alphabet_list)]) + for i in range(len(seqs)): + for j, s in enumerate(seqs[i] + include_stop*'*'): + seq_data_check[i, j, list(alphabet_list).index(s)] = 1 + + # Setup data source. + if source_type == 'fasta': + # Save as external file. 
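+        # Write a small fasta file; sequence 'one' is split across two lines
+        # to check that multi-line records are concatenated on loading.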
+ source = 'test_seqs.fasta' + with open(source, 'w') as fw: + text = """>one +AAT +C +>two +CA +>three +T +""" + fw.write(text) + elif source_type == 'list': + source = seqs + + # Load dataset. + dataset = BiosequenceDataset(source, source_type, alphabet, + include_stop=include_stop) + + # Check. + assert torch.allclose(dataset.L_data, + torch.tensor(L_data_check, dtype=torch.float64)) + assert dataset.max_length == max_length_check + assert len(dataset) == data_size_check + assert dataset.data_size == data_size_check + assert dataset.alphabet_length == len(alphabet_list) + assert torch.allclose(dataset.seq_data, seq_data_check) + ind = torch.tensor([0, 2]) + assert torch.allclose(dataset[ind][0], + torch.cat([seq_data_check[0, None, :, :], + seq_data_check[2, None, :, :]])) + assert torch.allclose(dataset[ind][1], torch.tensor([4. + include_stop, + 1. + include_stop])) + dataload = torch.utils.data.DataLoader(dataset, batch_size=2) + for seq_data, L_data in dataload: + assert seq_data.shape[0] == L_data.shape[0] diff --git a/tests/contrib/mue/test_missingdatahmm.py b/tests/contrib/mue/test_missingdatahmm.py new file mode 100644 index 0000000000..ee19f4b31d --- /dev/null +++ b/tests/contrib/mue/test_missingdatahmm.py @@ -0,0 +1,173 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM +from pyro.distributions import Categorical, DiscreteHMM + + +def test_hmm_log_prob(): + + a0 = torch.tensor([0.9, 0.08, 0.02]) + a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]]) + e = torch.tensor([[0.99, 0.01], [0.01, 0.99], [0.5, 0.5]]) + + x = torch.tensor([[0., 1.], + [1., 0.], + [0., 1.], + [0., 1.], + [1., 0.], + [0., 0.]]) + + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) + lp = hmm_distr.log_prob(x) + + f = a0 * e[:, 1] + f = torch.matmul(f, a) * e[:, 0] + f = torch.matmul(f, a) * e[:, 1] + f = torch.matmul(f, a) * e[:, 1] + f = torch.matmul(f, a) * e[:, 0] + chk_lp = torch.log(torch.sum(f)) + + assert torch.allclose(lp, chk_lp) + + # Batch values. + x = torch.cat([ + x[None, :, :], + torch.tensor([[1., 0.], + [1., 0.], + [1., 0.], + [0., 0.], + [0., 0.], + [0., 0.]])[None, :, :]], dim=0) + lp = hmm_distr.log_prob(x) + + f = a0 * e[:, 0] + f = torch.matmul(f, a) * e[:, 0] + f = torch.matmul(f, a) * e[:, 0] + chk_lp = torch.cat([chk_lp[None], torch.log(torch.sum(f))[None]]) + + assert torch.allclose(lp, chk_lp) + + # Batch both parameters and values. 
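+    # (All-zero rows in the one-hot observations mark missing/padded steps,
+    # so they contribute no emission factors in the hand-computed checks.)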
+ a0 = torch.cat([a0[None, :], torch.tensor([0.2, 0.7, 0.1])[None, :]]) + a = torch.cat([ + a[None, :, :], + torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]] + )[None, :, :]], dim=0) + e = torch.cat([ + e[None, :, :], + torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :]], + dim=0) + hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), + torch.log(e)) + lp = hmm_distr.log_prob(x) + + f = a0[1, :] * e[1, :, 0] + f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] + f = torch.matmul(f, a[1, :, :]) * e[1, :, 0] + chk_lp = torch.cat([chk_lp[0][None], torch.log(torch.sum(f))[None]]) + + assert torch.allclose(lp, chk_lp) + + +@pytest.mark.parametrize('batch_initial', [False, True]) +@pytest.mark.parametrize('batch_transition', [False, True]) +@pytest.mark.parametrize('batch_observation', [False, True]) +@pytest.mark.parametrize('batch_data', [False, True]) +def test_shapes(batch_initial, batch_transition, batch_observation, batch_data): + + # Dimensions. + batch_size = 3 + state_dim, observation_dim, num_steps = 4, 5, 6 + + # Model initialization. + initial_logits = torch.randn([batch_size]*batch_initial + [state_dim]) + initial_logits = (initial_logits - + initial_logits.logsumexp(-1, True)) + transition_logits = torch.randn([batch_size]*batch_transition + + [state_dim, state_dim]) + transition_logits = (transition_logits - + transition_logits.logsumexp(-1, True)) + observation_logits = torch.randn([batch_size]*batch_observation + + [state_dim, observation_dim]) + observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + + hmm = MissingDataDiscreteHMM(initial_logits, transition_logits, + observation_logits) + + # Random observations. + value = (torch.randint(observation_dim, + [batch_size]*batch_data + [num_steps]).unsqueeze(-1) + == torch.arange(observation_dim)).double() + + # Log probability. + lp = hmm.log_prob(value) + + # Check shapes: + if all([not batch_initial, not batch_transition, not batch_observation, + not batch_data]): + assert lp.shape == () + else: + assert lp.shape == (batch_size,) + + +@pytest.mark.parametrize('batch_initial', [False, True]) +@pytest.mark.parametrize('batch_transition', [False, True]) +@pytest.mark.parametrize('batch_observation', [False, True]) +@pytest.mark.parametrize('batch_data', [False, True]) +def test_DiscreteHMM_comparison(batch_initial, batch_transition, + batch_observation, batch_data): + # Dimensions. + batch_size = 3 + state_dim, observation_dim, num_steps = 4, 5, 6 + + # -- Model setup --. + transition_logits_vldhmm = torch.randn([batch_size]*batch_transition + + [state_dim, state_dim]) + transition_logits_vldhmm = (transition_logits_vldhmm - + transition_logits_vldhmm.logsumexp(-1, True)) + # Adjust for DiscreteHMM broadcasting convention. + transition_logits_dhmm = transition_logits_vldhmm.unsqueeze(-3) + # Convert between discrete HMM convention for initial state and variable + # length HMM convention. 
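+    # DiscreteHMM samples a transition from the initial state before the
+    # first emission, whereas MissingDataDiscreteHMM emits directly from the
+    # initial state, so the equivalent initial distribution is the
+    # DiscreteHMM initial distribution propagated through one transition
+    # (computed in log space below).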
+ initial_logits_dhmm = torch.randn([batch_size]*batch_initial + [state_dim]) + initial_logits_dhmm = (initial_logits_dhmm - + initial_logits_dhmm.logsumexp(-1, True)) + initial_logits_vldhmm = (initial_logits_dhmm.unsqueeze(-1) + + transition_logits_vldhmm).logsumexp(-2) + observation_logits = torch.randn([batch_size]*batch_observation + + [state_dim, observation_dim]) + observation_logits = (observation_logits - + observation_logits.logsumexp(-1, True)) + # Create distribution object for DiscreteHMM + observation_dist = Categorical(logits=observation_logits.unsqueeze(-3)) + + vldhmm = MissingDataDiscreteHMM(initial_logits_vldhmm, + transition_logits_vldhmm, + observation_logits) + dhmm = DiscreteHMM(initial_logits_dhmm, transition_logits_dhmm, + observation_dist) + + # Random observations. + value = torch.randint(observation_dim, + [batch_size]*batch_data + [num_steps]) + value_oh = (value.unsqueeze(-1) + == torch.arange(observation_dim)).double() + + # -- Check. -- + # Log probability. + lp_vldhmm = vldhmm.log_prob(value_oh) + lp_dhmm = dhmm.log_prob(value) + # Shapes. + if all([not batch_initial, not batch_transition, not batch_observation, + not batch_data]): + assert lp_vldhmm.shape == () + else: + assert lp_vldhmm.shape == (batch_size,) + # Values. + assert torch.allclose(lp_vldhmm, lp_dhmm) diff --git a/tests/contrib/mue/test_models.py b/tests/contrib/mue/test_models.py new file mode 100644 index 0000000000..5f2ba9634b --- /dev/null +++ b/tests/contrib/mue/test_models.py @@ -0,0 +1,93 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +from torch.optim import Adam + +import pyro +from pyro.contrib.mue.dataloaders import BiosequenceDataset +from pyro.contrib.mue.models import FactorMuE, ProfileHMM +from pyro.optim import MultiStepLR + + +@pytest.mark.parametrize('jit', [False, True]) +def test_ProfileHMM_smoke(jit): + # Setup dataset. + seqs = ['BABBA', 'BAAB', 'BABBB'] + alph = 'AB' + dataset = BiosequenceDataset(seqs, 'list', alph) + + # Infer. + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.1}, + 'milestones': [20, 100, 1000, 2000], + 'gamma': 0.5}) + model = ProfileHMM(int(dataset.max_length*1.1), dataset.alphabet_length) + n_epochs = 5 + batch_size = 2 + losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit) + + assert not np.isnan(losses[-1]) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset, dataset, jit) + assert train_lp < 0. + assert test_lp < 0. + assert train_perplex > 0. + assert test_perplex > 0. + + +@pytest.mark.parametrize('indel_factor_dependence', [False, True]) +@pytest.mark.parametrize('z_prior_distribution', ['Normal', 'Laplace']) +@pytest.mark.parametrize('ARD_prior', [False, True]) +@pytest.mark.parametrize('substitution_matrix', [False, True]) +@pytest.mark.parametrize('jit', [False, True]) +def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution, + ARD_prior, substitution_matrix, jit): + # Setup dataset. + seqs = ['BABBA', 'BAAB', 'BABBB'] + alph = 'AB' + dataset = BiosequenceDataset(seqs, 'list', alph) + + # Infer. 
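+    # A two-dimensional latent space and a short annealing schedule keep the
+    # smoke test fast; the _beta_anneal schedule itself is checked directly
+    # further below.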
+ z_dim = 2 + scheduler = MultiStepLR({'optimizer': Adam, + 'optim_args': {'lr': 0.1}, + 'milestones': [20, 100, 1000, 2000], + 'gamma': 0.5}) + model = FactorMuE(dataset.max_length, dataset.alphabet_length, z_dim, + indel_factor_dependence=indel_factor_dependence, + z_prior_distribution=z_prior_distribution, + ARD_prior=ARD_prior, + substitution_matrix=substitution_matrix) + n_epochs = 5 + anneal_length = 2 + batch_size = 2 + losses = model.fit_svi(dataset, n_epochs, anneal_length, batch_size, + scheduler, jit) + + # Reconstruct. + recon = model._reconstruct_regressor_seq(dataset, 1, pyro.param) + + assert not np.isnan(losses[-1]) + assert recon.shape == (1, max([len(seq) for seq in seqs]), len(alph)) + + assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5)) + assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.)) + + # Evaluate. + train_lp, test_lp, train_perplex, test_perplex = model.evaluate( + dataset, dataset, jit) + assert train_lp < 0. + assert test_lp < 0. + assert train_perplex > 0. + assert test_perplex > 0. + + # Embedding. + z_locs, z_scales = model.embed(dataset) + assert z_locs.shape == (len(dataset), z_dim) + assert z_scales.shape == (len(dataset), z_dim) + assert torch.all(z_scales > 0.) diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py new file mode 100644 index 0000000000..1fd4f672a9 --- /dev/null +++ b/tests/contrib/mue/test_statearrangers.py @@ -0,0 +1,234 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.contrib.mue.statearrangers import Profile, mg2k + + +def simpleprod(lst): + # Product of list of scalar tensors, as numpy would do it. + if len(lst) == 0: + return torch.tensor(1.) + else: + return torch.prod(torch.cat([elem[None] for elem in lst])) + + +@pytest.mark.parametrize('M', [2, 20]) +@pytest.mark.parametrize('batch_size', [None, 5]) +@pytest.mark.parametrize('substitute', [False, True]) +def test_profile_alternate_imp(M, batch_size, substitute): + + # --- Setup random model. --- + pf_arranger = Profile(M) + + u1 = torch.rand((M+1, 3)) + u1[M, :] = 0 # Assume u_{M+1, j} = 0 for j in {0, 1, 2} in Eqn. S40. + u = torch.cat([(1-u1)[:, :, None], u1[:, :, None]], dim=2) + r1 = torch.rand((M+1, 3)) + r1[M, :] = 1 # Assume r_{M+1, j} = 1 for j in {0, 1, 2} in Eqn. S40. + r = torch.cat([(1-r1)[:, :, None], r1[:, :, None]], dim=2) + s = torch.rand((M, 4)) + s = s/torch.sum(s, dim=1, keepdim=True) + c = torch.rand((M+1, 4)) + c = c/torch.sum(c, dim=1, keepdim=True) + + if batch_size is not None: + s = torch.rand((batch_size, M, 4)) + s = s/torch.sum(s, dim=2, keepdim=True) + u1 = torch.rand((batch_size, M+1, 3)) + u1[:, M, :] = 0 + u = torch.cat([(1-u1)[:, :, :, None], u1[:, :, :, None]], dim=3) + + # Compute forward pass of state arranger to get HMM parameters. + # Don't use dimension M, assumed fixed by statearranger. + if substitute: + ll = torch.rand((4, 5)) + ll = ll/torch.sum(ll, dim=1, keepdim=True) + a0ln, aln, eln = pf_arranger.forward( + torch.log(s), torch.log(c), + torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :]), + torch.log(ll)) + else: + a0ln, aln, eln = pf_arranger.forward( + torch.log(s), torch.log(c), + torch.log(r[:-1, :]), torch.log(u[..., :-1, :, :])) + + # - Remake HMM parameters to check. 
- + # Here we implement Equation S40 from the MuE paper + # (https://www.biorxiv.org/content/10.1101/2020.07.31.231381v1.full.pdf) + # more directly, iterating over all the indices of the transition matrix + # and initial transition vector. + K = 2*M + 1 + if batch_size is None: + batch_dim_size = 1 + r1 = r1.unsqueeze(0) + u1 = u1.unsqueeze(0) + s = s.unsqueeze(0) + c = c.unsqueeze(0) + if substitute: + ll = ll.unsqueeze(0) + else: + batch_dim_size = batch_size + r1 = r1[None, :, :] * torch.ones([batch_size, 1, 1]) + c = c[None, :, :] * torch.ones([batch_size, 1, 1]) + if substitute: + ll = ll.unsqueeze(0) + chk_a = torch.zeros((batch_dim_size, K, K)) + chk_a0 = torch.zeros((batch_dim_size, K)) + chk_e = torch.zeros((batch_dim_size, K, 4)) + for b in range(batch_dim_size): + m, g = -1, 0 + u1[b][-1] = 1e-32 + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) + if m + 1 - g == mp and gp == 0: + chk_a0[b, kp] = (1 - r1[b, m+1-g, g])*(1 - u1[b, m+1-g, g]) + elif m + 1 - g < mp and gp == 0: + chk_a0[b, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)]) * + (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + elif m + 1 - g == mp and gp == 1: + chk_a0[b, kp] = r1[b, m+1-g, g] + elif m + 1 - g < mp and gp == 1: + chk_a0[b, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2])*u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)]) * r1[b, mp, 2]) + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) + for gp in range(2): + for mp in range(M+gp): + kp = mg2k(mp, gp, M) + if m + 1 - g == mp and gp == 0: + chk_a[b, k, kp] = (1 - r1[b, m+1-g, g] + )*(1 - u1[b, m+1-g, g]) + elif m + 1 - g < mp and gp == 0: + chk_a[b, k, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2]) * + u1[b, mpp, 2] + for mpp in range(m+2-g, mp)]) * + (1 - r1[b, mp, 2]) * (1 - u1[b, mp, 2])) + elif m + 1 - g == mp and gp == 1: + chk_a[b, k, kp] = r1[b, m+1-g, g] + elif m + 1 - g < mp and gp == 1: + chk_a[b, k, kp] = ( + (1 - r1[b, m+1-g, g]) * u1[b, m+1-g, g] * + simpleprod([(1 - r1[b, mpp, 2]) * + u1[b, mpp, 2] + for mpp in + range(m+2-g, mp)] + ) * r1[b, mp, 2]) + elif m == M and mp == M and g == 0 and gp == 0: + chk_a[b, k, kp] = 1. 
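+        # Next, build the expected emission matrix: match states (g == 0)
+        # emit the ancestral sequence s, insert states (g == 1) emit the
+        # insert sequence c (with an optional substitution matrix applied
+        # afterwards).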
+ + for g in range(2): + for m in range(M+g): + k = mg2k(m, g, M) + if g == 0: + chk_e[b, k, :] = s[b, m, :] + else: + chk_e[b, k, :] = c[b, m, :] + if substitute: + chk_e = torch.matmul(chk_e, ll) + + # --- Check --- + if batch_size is None: + chk_a = chk_a.squeeze() + chk_a0 = chk_a0.squeeze() + chk_e = chk_e.squeeze() + + assert torch.allclose(torch.sum(torch.exp(a0ln)), torch.tensor(1.), + atol=1e-3, rtol=1e-3) + assert torch.allclose(torch.sum(torch.exp(aln), axis=1), + torch.ones(2*M+1), atol=1e-3, + rtol=1e-3) + assert torch.allclose(chk_a0, torch.exp(a0ln)) + assert torch.allclose(chk_a, torch.exp(aln)) + assert torch.allclose(chk_e, torch.exp(eln)) + + +@pytest.mark.parametrize('batch_ancestor_seq', [False, True]) +@pytest.mark.parametrize('batch_insert_seq', [False, True]) +@pytest.mark.parametrize('batch_insert', [False, True]) +@pytest.mark.parametrize('batch_delete', [False, True]) +@pytest.mark.parametrize('batch_substitute', [False, True]) +def test_profile_shapes(batch_ancestor_seq, batch_insert_seq, batch_insert, + batch_delete, batch_substitute): + + M, D, B = 5, 2, 3 + K = 2*M + 1 + batch_size = 6 + pf_arranger = Profile(M) + sln = torch.randn([batch_size]*batch_ancestor_seq + [M, D]) + sln = sln - sln.logsumexp(-1, True) + cln = torch.randn([batch_size]*batch_insert_seq + [M+1, D]) + cln = cln - cln.logsumexp(-1, True) + rln = torch.randn([batch_size]*batch_insert + [M, 3, 2]) + rln = rln - rln.logsumexp(-1, True) + uln = torch.randn([batch_size]*batch_delete + [M, 3, 2]) + uln = uln - uln.logsumexp(-1, True) + lln = torch.randn([batch_size]*batch_substitute + [D, B]) + lln = lln - lln.logsumexp(-1, True) + a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) + + if all([not batch_ancestor_seq, not batch_insert_seq, + not batch_substitute]): + assert eln.shape == (K, B) + assert torch.allclose(eln.logsumexp(-1), torch.zeros(K)) + else: + assert eln.shape == (batch_size, K, B) + assert torch.allclose(eln.logsumexp(-1), torch.zeros(batch_size, K)) + + if all([not batch_insert, not batch_delete]): + assert a0ln.shape == (K,) + assert torch.allclose(a0ln.logsumexp(-1), torch.zeros(1)) + assert aln.shape == (K, K) + assert torch.allclose(aln.logsumexp(-1), torch.zeros(K)) + else: + assert a0ln.shape == (batch_size, K) + assert torch.allclose(a0ln.logsumexp(-1), torch.zeros(batch_size)) + assert aln.shape == (batch_size, K, K) + assert torch.allclose(aln.logsumexp(-1), torch.zeros((batch_size, K))) + + +@pytest.mark.parametrize('M', [2, 20]) # , 20 +def test_profile_trivial_cases(M): + # Trivial case: indel probabability of zero. Expected value of + # HMM should match ancestral sequence times substitution matrix. + + # --- Setup model. --- + D, B = 2, 2 + batch_size = 5 + pf_arranger = Profile(M) + sln = torch.randn([batch_size, M, D]) + sln = sln - sln.logsumexp(-1, True) + cln = torch.randn([batch_size, M+1, D]) + cln = cln - cln.logsumexp(-1, True) + rln = torch.cat([torch.zeros([M, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) + uln = torch.cat([torch.zeros([M, 3, 1]), + -1/pf_arranger.epsilon*torch.ones([M, 3, 1])], axis=-1) + lln = torch.randn([D, B]) + lln = lln - lln.logsumexp(-1, True) + + a0ln, aln, eln = pf_arranger.forward(sln, cln, rln, uln, lln) + + # --- Compute expected value per step. 
--- + Eyln = torch.zeros([batch_size, M, B]) + ai = a0ln + for j in range(M): + Eyln[:, j, :] = torch.logsumexp(ai.unsqueeze(-1) + eln, axis=-2) + ai = torch.logsumexp(ai.unsqueeze(-1) + aln, axis=-2) + + print(aln.exp()) + no_indel = torch.logsumexp(sln.unsqueeze(-1) + lln.unsqueeze(-3), axis=-2) + assert torch.allclose(Eyln, no_indel) diff --git a/tests/test_examples.py b/tests/test_examples.py index 374f3563f8..9298a78f11 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -59,6 +59,10 @@ 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save', 'contrib/oed/ab_test.py --num-vi-steps=10 --num-bo-steps=2', 'contrib/timeseries/gp_models.py -m imgp --test --num-steps=2', 'contrib/timeseries/gp_models.py -m lcmgp --test --num-steps=2', @@ -142,6 +146,10 @@ 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --cuda --cpu-data --pin-mem', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --cuda', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --cuda --cpu-data --pin-mem', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm.py --num-epochs=1 --cuda', 'dmm.py --num-epochs=1 --num-iafs=1 --cuda', @@ -210,6 +218,10 @@ def xfail_jit(*args, **kwargs): 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/epidemiology/regional.py --jit -ss=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --svi', xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), + 'contrib/mue/FactorMuE.py --test --small --include-stop --no-plots --no-save --jit', + 'contrib/mue/FactorMuE.py --test --small -ard -idfac --no-substitution-matrix --no-plots --no-save --jit', + 'contrib/mue/ProfileHMM.py --test --small --no-plots --no-save --jit', + 'contrib/mue/ProfileHMM.py --test --small --include-stop --no-plots --no-save --jit', xfail_jit('dmm.py --num-epochs=1 --jit'), xfail_jit('dmm.py --num-epochs=1 --num-iafs=1 --jit'), 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit', diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst index d0bf0b113d..9ddc65b2b0 100644 --- a/tutorial/source/index.rst +++ b/tutorial/source/index.rst @@ -172,6 +172,14 @@ List of Tutorials epi_regional sir_hmc +.. toctree:: + :maxdepth: 1 + :caption: Application: Biological sequences + :name: biological-sequences + + mue_profile + mue_factor + .. 
toctree::
    :maxdepth: 1
    :caption: Application: Experimental Design

diff --git a/tutorial/source/mue_factor.rst b/tutorial/source/mue_factor.rst
new file mode 100644
index 0000000000..d4ec19ae4b
--- /dev/null
+++ b/tutorial/source/mue_factor.rst
@@ -0,0 +1,11 @@
+Example: Probabilistic PCA + MuE (FactorMuE)
+============================================
+
+`View FactorMuE.py on github`__
+
+.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/mue/FactorMuE.py
+
+__ github_
+
+.. literalinclude:: ../../examples/contrib/mue/FactorMuE.py
+   :language: python
diff --git a/tutorial/source/mue_profile.rst b/tutorial/source/mue_profile.rst
new file mode 100644
index 0000000000..41d1b704f0
--- /dev/null
+++ b/tutorial/source/mue_profile.rst
@@ -0,0 +1,11 @@
+Example: Constant + MuE (Profile HMM)
+=====================================
+
+`View ProfileHMM.py on github`__
+
+.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/mue/ProfileHMM.py
+
+__ github_
+
+.. literalinclude:: ../../examples/contrib/mue/ProfileHMM.py
+   :language: python