From b7049456a6ef06b95f41b4025949d355b02fd1c4 Mon Sep 17 00:00:00 2001
From: Eli Weinstein
Date: Thu, 22 Jul 2021 17:19:41 -0400
Subject: [PATCH] Sampling functions for the MuE/missing data discrete HMM.
 (#2898)

---
 examples/contrib/mue/FactorMuE.py        |  24 +-
 examples/contrib/mue/ProfileHMM.py       |  13 +-
 pyro/contrib/mue/dataloaders.py          |  59 ++++
 pyro/contrib/mue/missingdatahmm.py       | 210 ++++++++++++
 pyro/contrib/mue/models.py               |  35 +-
 tests/contrib/mue/test_dataloaders.py    |  47 ++-
 tests/contrib/mue/test_missingdatahmm.py | 390 ++++++++++++++++++++++-
 tests/contrib/mue/test_statearrangers.py |  42 +--
 8 files changed, 775 insertions(+), 45 deletions(-)

diff --git a/examples/contrib/mue/FactorMuE.py b/examples/contrib/mue/FactorMuE.py
index 4bf8f24440..1a06e97fcc 100644
--- a/examples/contrib/mue/FactorMuE.py
+++ b/examples/contrib/mue/FactorMuE.py
@@ -23,8 +23,8 @@
 Reference:
 [1] E. N. Weinstein, D. S. Marks (2021)
-"Generative probabilistic biological sequence models that account for
-mutational variability"
+"A structured observation distribution for generative biological sequence
+prediction and forecasting"
 https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
 """
@@ -62,10 +62,10 @@ def generate_data(small_test, include_stop, device):
 def main(args):
 
     # Load dataset.
-    if args.cpu_data and args.cuda:
+    if args.cpu_data or not args.cuda:
         device = torch.device("cpu")
     else:
-        device = None
+        device = torch.device("cuda")
     if args.test:
         dataset = generate_data(args.small, args.include_stop, device)
     else:
@@ -84,7 +84,7 @@ def main(args):
         # Specific data split seed, for comparability across models and
         # parameter initializations.
         pyro.set_rng_seed(args.rng_data_seed)
-        indices = torch.randperm(sum(data_lengths)).tolist()
+        indices = torch.randperm(sum(data_lengths), device=device).tolist()
         dataset_train, dataset_test = [
             torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
             for offset, length in zip(
@@ -131,7 +131,12 @@ def main(args):
     )
     n_epochs = args.n_epochs
     losses = model.fit_svi(
-        dataset_train, n_epochs, args.anneal, args.batch_size, scheduler, args.jit
+        dataset_train,
+        n_epochs,
+        args.anneal,
+        args.batch_size,
+        scheduler,
+        args.jit,
     )
 
     # Evaluate.
@@ -233,13 +238,18 @@ def main(args):
     )
     with open(
         os.path.join(
-            args.out_folder, "FactorMuE_results.input_{}.txt".format(time_stamp)
+            args.out_folder,
+            "FactorMuE_results.input_{}.txt".format(time_stamp),
         ),
         "w",
     ) as ow:
         ow.write("[args]\n")
+        args.latent_seq_length = model.latent_seq_length
+        args.latent_alphabet = model.latent_alphabet_length
         for elem in list(args.__dict__.keys()):
            ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
+        ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
+        ow.write("max_length = {}\n".format(dataset.max_length))
 
 
 if __name__ == "__main__":
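
A minimal sketch (outside the patch proper) of the corrected device-selection
logic that both example scripts now share; select_device is an illustrative
helper, and the flags mirror the scripts' --cuda/--cpu-data arguments:

    import torch

    def select_device(cuda: bool, cpu_data: bool) -> torch.device:
        # Data stays on the CPU when --cpu-data is passed or CUDA is
        # disabled; otherwise it goes to the GPU. The old branch returned
        # None in the CUDA case, leaving tensors wherever they were created.
        if cpu_data or not cuda:
            return torch.device("cpu")
        return torch.device("cuda")

    assert select_device(cuda=True, cpu_data=False) == torch.device("cuda")
    assert select_device(cuda=False, cpu_data=False) == torch.device("cpu")
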
diff --git a/examples/contrib/mue/ProfileHMM.py b/examples/contrib/mue/ProfileHMM.py
index ef1a4ef3ad..a7e09353ce 100644
--- a/examples/contrib/mue/ProfileHMM.py
+++ b/examples/contrib/mue/ProfileHMM.py
@@ -27,8 +27,8 @@
 Cambridge university press
 [2] E. N. Weinstein, D. S. Marks (2021)
-"Generative probabilistic biological sequence models that account for
-mutational variability"
+"A structured observation distribution for generative biological sequence
+prediction and forecasting"
 https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
 """
@@ -68,10 +68,10 @@ def main(args):
     pyro.set_rng_seed(args.rng_seed)
 
     # Load dataset.
-    if args.cpu_data and args.cuda:
+    if args.cpu_data or not args.cuda:
         device = torch.device("cpu")
     else:
-        device = None
+        device = torch.device("cuda")
     if args.test:
         dataset = generate_data(args.small, args.include_stop, device)
     else:
@@ -90,7 +90,7 @@ def main(args):
         # Specific data split seed, for comparability across models and
         # parameter initializations.
         pyro.set_rng_seed(args.rng_data_seed)
-        indices = torch.randperm(sum(data_lengths)).tolist()
+        indices = torch.randperm(sum(data_lengths), device=device).tolist()
         dataset_train, dataset_test = [
             torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
             for offset, length in zip(
@@ -200,8 +200,11 @@ def main(args):
         "w",
     ) as ow:
         ow.write("[args]\n")
+        args.latent_seq_length = model.latent_seq_length
         for elem in list(args.__dict__.keys()):
             ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
+        ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
+        ow.write("max_length = {}\n".format(dataset.max_length))
 
 
 if __name__ == "__main__":
diff --git a/pyro/contrib/mue/dataloaders.py b/pyro/contrib/mue/dataloaders.py
index 4df5444739..82370b2694 100644
--- a/pyro/contrib/mue/dataloaders.py
+++ b/pyro/contrib/mue/dataloaders.py
@@ -141,3 +141,62 @@ def __len__(self):
 
     def __getitem__(self, ind):
         return (self.seq_data[ind], self.L_data[ind])
+
+
+def write(x, alphabet, file, truncate_stop=False, append=False, scores=None):
+    """
+    Write sequence samples to file.
+
+    :param ~torch.Tensor x: One-hot encoded sequences, with size
+        ``(data_size, seq_length, alphabet_length)``. May be padded with
+        zeros for variable length sequences.
+    :param ~np.array alphabet: Alphabet.
+    :param str file: Output file, where sequences will be written
+        in fasta format.
+    :param bool truncate_stop: If True, sequences will be truncated at the
+        first stop symbol (i.e. the stop symbol and everything after will not
+        be written). If False, the whole sequence will be written, including
+        any internal stop symbols.
+    :param bool append: If True, sequences are appended to the end of the
+        output file. If False, the file is first erased.
+    :param scores: Optional iterable of per-sequence values to use as the
+        fasta record identifiers in place of the default numeric index.
+    """
+    print_alphabet = np.array(list(alphabet) + [""])
+    x = torch.cat([x, torch.zeros(list(x.shape[:2]) + [1])], -1)
+    if truncate_stop:
+        mask = (
+            torch.cumsum(
+                torch.matmul(
+                    x, torch.tensor(print_alphabet == "*", dtype=torch.double)
+                ),
+                -1,
+            )
+            > 0
+        ).to(torch.double)
+        x = x * (1 - mask).unsqueeze(-1)
+        x[:, :, -1] = mask
+    else:
+        x[:, :, -1] = (torch.sum(x, -1) < 0.5).to(torch.double)
+    index = (
+        torch.matmul(x, torch.arange(x.shape[-1], dtype=torch.double))
+        .to(torch.long)
+        .cpu()
+        .numpy()
+    )
+    if scores is None:
+        seqs = [
+            ">{}\n".format(j) + "".join(elem) + "\n"
+            for j, elem in enumerate(print_alphabet[index])
+        ]
+    else:
+        seqs = [
+            ">{}\n".format(j) + "".join(elem) + "\n"
+            for j, elem in zip(scores, print_alphabet[index])
+        ]
+    if append:
+        open_flag = "a"
+    else:
+        open_flag = "w"
+    with open(file, open_flag) as fw:
+        fw.write("".join(seqs))
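
A minimal sketch of the new write helper in use, assuming double precision
(write mixes its input with double-typed index tensors internally); the
dataset contents and file name here are illustrative:

    import torch
    from pyro.contrib.mue.dataloaders import BiosequenceDataset, write

    torch.set_default_dtype(torch.float64)
    dataset = BiosequenceDataset(
        ["AATC*C", "CA*"], "list", "ACGT*", include_stop=False
    )
    # Truncate each sequence at its first stop symbol, overwriting the file.
    write(dataset.seq_data, dataset.alphabet, "samples.fasta",
          truncate_stop=True, append=False)
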
diff --git a/pyro/contrib/mue/missingdatahmm.py b/pyro/contrib/mue/missingdatahmm.py
index 26084c10e7..6c04d47f9e 100644
--- a/pyro/contrib/mue/missingdatahmm.py
+++ b/pyro/contrib/mue/missingdatahmm.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
+from torch.distributions import Categorical, OneHotCategorical
 
 from pyro.distributions import constraints
 from pyro.distributions.hmm import _sequential_logmatmulexp
@@ -110,3 +111,212 @@ def log_prob(self, value):
         # Marginalize out final state.
         result = result.logsumexp(-1)
         return result
+
+    def sample(self, sample_shape=torch.Size([])):
+        """
+        :param ~torch.Size sample_shape: Sample shape; the last dimension
+            must be ``num_steps``, and the shape must be broadcastable to
+            ``(batch_size, num_steps)``, with ``batch_size`` an int, not a tuple.
+        """
+        # shape: batch_size x num_steps x categorical_size
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(sample_shape) + [1]),
+            torch.Size((1, 1, self.event_shape[-1])),
+        )
+        # state: batch_size x state_dim
+        state = OneHotCategorical(logits=self.initial_logits).sample()
+        # sample: batch_size x num_steps x categorical_size
+        sample = torch.zeros(shape)
+        for i in range(shape[-2]):
+            # batch_size x 1 x state_dim @
+            # batch_size x state_dim x categorical_size
+            obs_logits = torch.matmul(
+                state.unsqueeze(-2), self.observation_logits
+            ).squeeze(-2)
+            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
+            # batch_size x 1 x state_dim @
+            # batch_size x state_dim x state_dim
+            trans_logits = torch.matmul(
+                state.unsqueeze(-2), self.transition_logits
+            ).squeeze(-2)
+            state = OneHotCategorical(logits=trans_logits).sample()
+
+        return sample
+
+    def filter(self, value):
+        """
+        Compute the log marginal probability of the state variable at each
+        step, conditional on the observations up to and including that step.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical output.
+        """
+        # batch_size x num_steps x state_dim
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(value.shape[:-1]) + [1]),
+            torch.Size((1, 1, self.initial_logits.shape[-1])),
+        )
+        filter = torch.zeros(shape)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+
+        # Forward pass. (This could be parallelized using the
+        # Sarkka & Garcia-Fernandez method.)
+        filter[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
+        filter[..., 0, :] = filter[..., 0, :] - torch.logsumexp(
+            filter[..., 0, :], -1, True
+        )
+        for i in range(1, shape[-2]):
+            filter[..., i, :] = torch.logsumexp(
+                filter[..., i - 1, :, None] + result[..., i - 1, :, :], -2
+            )
+            filter[..., i, :] = filter[..., i, :] - torch.logsumexp(
+                filter[..., i, :], -1, True
+            )
+        return filter
+
+    def smooth(self, value):
+        """
+        Compute posterior log marginals of the state at each position (smoothing).
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical output.
+        """
+        # Compute filter and initialize.
+        filter = self.filter(value)
+        shape = filter.shape
+        backfilter = torch.zeros(shape)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+        # Construct backwards filter.
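+        # The loop below runs the backward recursion in log space:
+        #   backfilter[i - 1, s] = logsumexp_j(transition[s, j]
+        #       + value_logits[i, j] + backfilter[i, j]),
+        # where a missing observation (an all-zero one-hot row in ``value``)
+        # contributes value_logits = 0 and so drops out of the recursion.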
+        for i in range(shape[-2] - 1, 0, -1):
+            backfilter[..., i - 1, :] = torch.logsumexp(
+                backfilter[..., i, None, :] + result[..., i - 1, :, :], -1
+            )
+
+        # Compute smoothed version.
+        smooth = filter + backfilter
+        smooth = smooth - torch.logsumexp(smooth, -1, True)
+        return smooth
+
+    def sample_states(self, value):
+        """
+        Sample states with the forward filtering-backward sampling algorithm.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical output.
+        """
+        filter = self.filter(value)
+        shape = filter.shape
+        joint = filter.unsqueeze(-1) + self.transition_logits.unsqueeze(-3)
+        states = torch.zeros(shape[:-1], dtype=torch.long)
+        states[..., -1] = Categorical(logits=filter[..., -1, :]).sample()
+        for i in range(shape[-2] - 1, 0, -1):
+            logits = torch.gather(
+                joint[..., i - 1, :, :],
+                -1,
+                states[..., i, None, None]
+                * torch.ones([shape[-1], 1], dtype=torch.long),
+            ).squeeze(-1)
+            states[..., i - 1] = Categorical(logits=logits).sample()
+        return states
+
+    def map_states(self, value):
+        """
+        Compute the maximum a posteriori (MAP) estimate of the state variable
+        with the Viterbi algorithm.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical output.
+        """
+        # Setup for Viterbi.
+        # batch_size x num_steps x state_dim
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(value.shape[:-1]) + [1]),
+            torch.Size((1, 1, self.initial_logits.shape[-1])),
+        )
+        state_logits = torch.zeros(shape)
+        state_traceback = torch.zeros(shape, dtype=torch.long)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+
+        # Forward pass.
+        state_logits[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
+        for i in range(1, shape[-2]):
+            transit_weights = (
+                state_logits[..., i - 1, :, None] + result[..., i - 1, :, :]
+            )
+            state_logits[..., i, :], state_traceback[..., i, :] = torch.max(
+                transit_weights, -2
+            )
+        # Traceback.
+        map_states = torch.zeros(shape[:-1], dtype=torch.long)
+        map_states[..., -1] = torch.argmax(state_logits[..., -1, :], -1)
+        for i in range(shape[-2] - 1, 0, -1):
+            map_states[..., i - 1] = torch.gather(
+                state_traceback[..., i, :], -1, map_states[..., i].unsqueeze(-1)
+            ).squeeze(-1)
+        return map_states
+
+    def given_states(self, states):
+        """
+        Distribution conditional on the state variable.
+
+        :param ~torch.Tensor states: State trajectory. Must be
+            integer-valued (long) and broadcastable to
+            ``(batch_size, num_steps)``.
+        """
+        shape = broadcast_shape(
+            list(self.batch_shape) + [1, 1],
+            list(states.shape[:-1]) + [1, 1],
+            [1, 1, self.observation_logits.shape[-1]],
+        )
+        states_index = states.unsqueeze(-1) * torch.ones(shape, dtype=torch.long)
+        obs_logits = self.observation_logits * torch.ones(shape)
+        logits = torch.gather(obs_logits, -2, states_index)
+        return OneHotCategorical(logits=logits)
+
+    def sample_given_states(self, states):
+        """
+        Sample an observation conditional on the state variable.
+
+        :param ~torch.Tensor states: State trajectory. Must be
+            integer-valued (long) and broadcastable to
+            ``(batch_size, num_steps)``.
+        """
+        conditional = self.given_states(states)
+        return conditional.sample()
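
A sketch of the new inference entry points on MissingDataDiscreteHMM, with
illustrative two-state parameters; an all-zero row in ``value`` marks a
missing observation:

    import torch
    from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM

    initial = torch.log(torch.tensor([0.9, 0.1]))
    transition = torch.log(torch.tensor([[0.7, 0.3], [0.2, 0.8]]))
    emission = torch.log(torch.tensor([[0.9, 0.1], [0.1, 0.9]]))
    hmm = MissingDataDiscreteHMM(initial, transition, emission)

    # One-hot observations; the middle step is missing.
    value = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
    posteriors = hmm.smooth(value)         # log state marginals per step
    trajectory = hmm.sample_states(value)  # forward filter, backward sample
    viterbi = hmm.map_states(value)        # MAP state path
    resampled = hmm.sample_given_states(viterbi)  # observations given states
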
diff --git a/pyro/contrib/mue/models.py b/pyro/contrib/mue/models.py
index db041ea326..7cf36aa3cd 100644
--- a/pyro/contrib/mue/models.py
+++ b/pyro/contrib/mue/models.py
@@ -159,17 +159,26 @@ def guide(self, seq_data, local_scale):
         )
         insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape))
         pyro.sample(
-            "insert", dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3)
+            "insert",
+            dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3),
         )
         delete_q_mn = pyro.param(
             "delete_q_mn", torch.ones(self.indel_shape) * self.indel_prior
         )
         delete_q_sd = pyro.param("delete_q_sd", torch.zeros(self.indel_shape))
         pyro.sample(
-            "delete", dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3)
+            "delete",
+            dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3),
         )
 
-    def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, jit=False):
+    def fit_svi(
+        self,
+        dataset,
+        epochs=2,
+        batch_size=1,
+        scheduler=None,
+        jit=False,
+    ):
         """
         Infer approximate posterior with stochastic variational inference.
 
@@ -196,10 +205,18 @@ def fit_svi(self, dataset, epochs=2, batch_size=1, scheduler=None, jit=False):
                     "gamma": 0.5,
                 }
             )
+        if self.is_cuda:
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
         # Initialize guide.
         self.guide(None, None)
         dataload = DataLoader(
-            dataset, batch_size=batch_size, shuffle=True, pin_memory=self.pin_memory
+            dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            pin_memory=self.pin_memory,
+            generator=torch.Generator(device=device),
         )
         # Setup stochastic variational inference.
         if jit:
@@ -703,8 +720,16 @@ def fit_svi(
                     "gamma": 0.5,
                 }
             )
+        if self.is_cuda:
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
         dataload = DataLoader(
-            dataset, batch_size=batch_size, shuffle=True, pin_memory=self.pin_memory
+            dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            pin_memory=self.pin_memory,
+            generator=torch.Generator(device=device),
         )
         # Initialize guide.
         for seq_data, L_data in dataload:
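
The explicit generator above pins the DataLoader's shuffling RNG to the
device the run uses (the motivation is inferred from the surrounding device
handling); a minimal standalone sketch, with TensorDataset as a stand-in for
the sequence dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cpu")  # "cuda" when the model sets is_cuda
    data = TensorDataset(torch.randn(8, 4))
    loader = DataLoader(data, batch_size=2, shuffle=True,
                        generator=torch.Generator(device=device))
    for (batch,) in loader:
        pass  # one shuffled pass over the data
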
diff --git a/tests/contrib/mue/test_dataloaders.py b/tests/contrib/mue/test_dataloaders.py
index e1e570c72c..00d7f4e6f0 100644
--- a/tests/contrib/mue/test_dataloaders.py
+++ b/tests/contrib/mue/test_dataloaders.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 
-from pyro.contrib.mue.dataloaders import BiosequenceDataset, alphabets
+from pyro.contrib.mue.dataloaders import BiosequenceDataset, alphabets, write
 
 
 @pytest.mark.parametrize("source_type", ["list", "fasta"])
@@ -70,3 +70,48 @@ def test_biosequencedataset(source_type, alphabet, include_stop):
         dataload = torch.utils.data.DataLoader(dataset, batch_size=2)
         for seq_data, L_data in dataload:
             assert seq_data.shape[0] == L_data.shape[0]
+
+
+def test_write():
+
+    # Define dataset.
+    seqs = ["AATC*C", "CA*", "T**"]
+    dataset = BiosequenceDataset(seqs, "list", "ACGT*", include_stop=False)
+    # With truncation at stop symbol.
+    # Write. (Erase any existing file first, then exercise append mode.)
+    with open("test_seqs.fasta", "w") as fw:
+        fw.write("")
+    write(
+        dataset.seq_data,
+        dataset.alphabet,
+        "test_seqs.fasta",
+        truncate_stop=True,
+        append=True,
+    )
+
+    # Reload.
+    dataset2 = BiosequenceDataset("test_seqs.fasta", "fasta", "dna", include_stop=True)
+    to_stop_lens = [4, 2, 1]
+    for j, to_stop_len in enumerate(to_stop_lens):
+        assert torch.allclose(
+            dataset.seq_data[j, :to_stop_len], dataset2.seq_data[j, :to_stop_len]
+        )
+        assert torch.allclose(
+            dataset2.seq_data[j, (to_stop_len + 1) :], torch.tensor(0.0)
+        )
+
+    # Without truncation at stop symbol.
+    # Write.
+    write(
+        dataset.seq_data,
+        dataset.alphabet,
+        "test_seqs.fasta",
+        truncate_stop=False,
+        append=False,
+    )
+
+    # Reload.
+    dataset2 = BiosequenceDataset(
+        "test_seqs.fasta", "fasta", "ACGT*", include_stop=False
+    )
+    assert torch.allclose(dataset.seq_data, dataset2.seq_data)
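
Taken together, the sampling and I/O pieces support a sample-then-write
round trip; a sketch with illustrative parameters, again in double precision
to match the double-typed helpers inside write:

    import torch
    from pyro.contrib.mue.dataloaders import write
    from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM

    torch.set_default_dtype(torch.float64)
    initial = torch.log(torch.tensor([0.6, 0.4]))
    transition = torch.log(torch.tensor([[0.7, 0.3], [0.2, 0.8]]))
    emission = torch.log(torch.tensor([[0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]))
    hmm = MissingDataDiscreteHMM(initial, transition, emission)

    seqs = hmm.sample(torch.Size([6]))  # one six-step one-hot sequence
    write(seqs, list("ACG"), "hmm_samples.fasta", truncate_stop=False)
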
diff --git a/tests/contrib/mue/test_missingdatahmm.py b/tests/contrib/mue/test_missingdatahmm.py
index ba7a38e6e7..1224e7b872 100644
--- a/tests/contrib/mue/test_missingdatahmm.py
+++ b/tests/contrib/mue/test_missingdatahmm.py
@@ -26,9 +26,9 @@ def test_hmm_log_prob():
     f = torch.matmul(f, a) * e[:, 1]
     f = torch.matmul(f, a) * e[:, 1]
     f = torch.matmul(f, a) * e[:, 0]
-    chk_lp = torch.log(torch.sum(f))
+    expected_lp = torch.log(torch.sum(f))
 
-    assert torch.allclose(lp, chk_lp)
+    assert torch.allclose(lp, expected_lp)
 
     # Batch values.
     x = torch.cat(
@@ -45,9 +45,9 @@ def test_hmm_log_prob():
     f = a0 * e[:, 0]
     f = torch.matmul(f, a) * e[:, 0]
     f = torch.matmul(f, a) * e[:, 0]
-    chk_lp = torch.cat([chk_lp[None], torch.log(torch.sum(f))[None]])
+    expected_lp = torch.cat([expected_lp[None], torch.log(torch.sum(f))[None]])
 
-    assert torch.allclose(lp, chk_lp)
+    assert torch.allclose(lp, expected_lp)
 
     # Batch both parameters and values.
     a0 = torch.cat([a0[None, :], torch.tensor([0.2, 0.7, 0.1])[None, :]])
@@ -73,9 +73,9 @@ def test_hmm_log_prob():
     f = a0[1, :] * e[1, :, 0]
     f = torch.matmul(f, a[1, :, :]) * e[1, :, 0]
     f = torch.matmul(f, a[1, :, :]) * e[1, :, 0]
-    chk_lp = torch.cat([chk_lp[0][None], torch.log(torch.sum(f))[None]])
+    expected_lp = torch.cat([expected_lp[0][None], torch.log(torch.sum(f))[None]])
 
-    assert torch.allclose(lp, chk_lp)
+    assert torch.allclose(lp, expected_lp)
 
 
 @pytest.mark.parametrize("batch_initial", [False, True])
@@ -178,3 +178,381 @@ def test_DiscreteHMM_comparison(
     assert lp_vldhmm.shape == (batch_size,)
     # Values.
     assert torch.allclose(lp_vldhmm, lp_dhmm)
+    # Filter.
+    filter_dhmm = dhmm.filter(value)
+    filter_vldhmm = vldhmm.filter(value_oh)
+    assert torch.allclose(filter_dhmm.logits, filter_vldhmm[..., -1, :])
+    # Check that the other computations run.
+    vldhmm.sample(value_oh.shape[:-1])
+    vldhmm.smooth(value_oh)
+    vldhmm.sample_states(value_oh)
+    map_states = vldhmm.map_states(value_oh)
+    vldhmm.sample_given_states(map_states)
+
+
+@pytest.mark.parametrize("batch_data", [False, True])
+def test_samples(batch_data):
+    initial_logits = torch.tensor([-100, 0, -100, -100], dtype=torch.float64)
+    transition_logits = torch.tensor(
+        [
+            [-100, -100, 0, -100],
+            [-100, -100, -100, 0],
+            [0, -100, -100, -100],
+            [-100, 0, -100, -100],
+        ],
+        dtype=torch.float64,
+    )
+    obs_logits = torch.tensor(
+        [[0, -100, -100], [-100, 0, -100], [-100, -100, 0], [-100, -100, 0]],
+        dtype=torch.float64,
+    )
+    if batch_data:
+        initial_logits = torch.tensor(
+            [[-100, 0, -100, -100], [0, -100, -100, -100]], dtype=torch.float64
+        )
+        transition_logits = transition_logits * torch.ones(
+            [2] + list(transition_logits.shape)
+        )
+        obs_logits = obs_logits * torch.ones([2] + list(obs_logits.shape))
+
+    model = MissingDataDiscreteHMM(initial_logits, transition_logits, obs_logits)
+
+    if not batch_data:
+        sample = model.sample(torch.Size([3]))
+        assert torch.allclose(
+            sample, torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
+        )
+    else:
+        sample = model.sample(torch.Size([2, 3]))
+        assert torch.allclose(
+            sample[0, :, :],
+            torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),
+        )
+        assert torch.allclose(
+            sample[1, :, :],
+            torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]),
+        )
+
+
+def indiv_filter(a0, a, e, x):
+    # Forward filtering, implemented without batching or vector operations.
+    alph = torch.zeros((x.shape[0], a0.shape[0]))
+    for j in range(a0.shape[0]):
+        vec = a0[j]
+        if torch.sum(x[0, :]) > 0.5:
+            vec = vec * torch.dot(x[0, :], e[j, :])
+        alph[0, j] = vec
+    alph[0, :] = alph[0, :] / torch.sum(alph[0, :])
+    for t in range(1, x.shape[0]):
+        for j in range(a0.shape[0]):
+            vec = torch.sum(alph[t - 1, :] * a[:, j])
+            if torch.sum(x[t, :]) > 0.5:
+                vec = vec * torch.dot(x[t, :], e[j, :])
+            alph[t, j] = vec
+        alph[t, :] = alph[t, :] / torch.sum(alph[t, :])
+    return torch.log(alph)
+
+
+def indiv_smooth(a0, a, e, x):
+    # Forward-backward smoothing, without batching or vector operations.
+    alph = indiv_filter(a0, a, e, x)
+    beta = torch.zeros(alph.shape)
+    beta[-1, :] = 1.0
+    for t in range(alph.shape[0] - 1, 0, -1):
+        for i in range(a0.shape[0]):
+            for j in range(a0.shape[0]):
+                vec = beta[t, j] * a[i, j]
+                if torch.sum(x[t, :]) > 0.5:
+                    vec = vec * torch.dot(x[t, :], e[j, :])
+                beta[t - 1, i] += vec
+    smooth = torch.exp(alph) * beta
+    smooth = smooth / torch.sum(smooth, -1, True)
+    return torch.log(smooth)
+
+
+def indiv_map_states(a0, a, e, x):
+    # Viterbi algorithm, implemented without batching or vector operations.
+
+    delta = torch.zeros((x.shape[0], a0.shape[0]))
+    for j in range(a0.shape[0]):
+        vec = a0[j]
+        if torch.sum(x[0, :]) > 0.5:
+            vec = vec * torch.dot(x[0, :], e[j, :])
+        delta[0, j] = vec
+    traceback = torch.zeros((x.shape[0], a0.shape[0]), dtype=torch.long)
+    for t in range(1, x.shape[0]):
+        for j in range(a0.shape[0]):
+            vec = delta[t - 1, :] * a[:, j]
+            if torch.sum(x[t, :]) > 0.5:
+                vec = vec * torch.dot(x[t, :], e[j, :])
+            delta[t, j] = torch.max(vec)
+            traceback[t, j] = torch.argmax(vec)
+    expected_map_states = torch.zeros(x.shape[0], dtype=torch.long)
+    expected_map_states[-1] = torch.argmax(delta[-1, :])
+    for t in range(x.shape[0] - 1, 0, -1):
+        expected_map_states[t - 1] = traceback[t, expected_map_states[t]]
+
+    return expected_map_states
+
+
+def test_state_infer():
+
+    # HMM parameters.
+    a0 = torch.tensor([0.9, 0.08, 0.02])
+    a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]])
+    e = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]])
+    # Observed value.
+    x = torch.tensor(
+        [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
+    )
+
+    expected_map_states = indiv_map_states(a0, a, e, x)
+    expected_filter = indiv_filter(a0, a, e, x)
+    expected_smooth = indiv_smooth(a0, a, e, x)
+
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    map_states = hmm_distr.map_states(x)
+    filter = hmm_distr.filter(x)
+    smooth = hmm_distr.smooth(x)
+
+    assert torch.allclose(map_states, expected_map_states)
+    assert torch.allclose(filter, expected_filter)
+    assert torch.allclose(smooth, expected_smooth)
+
+    # Batch values.
+    x = torch.cat(
+        [
+            x[None, :, :],
+            torch.tensor(
+                [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
+            )[None, :, :],
+        ],
+        dim=0,
+    )
+    map_states = hmm_distr.map_states(x)
+    filter = hmm_distr.filter(x)
+    smooth = hmm_distr.smooth(x)
+
+    expected_map_states = torch.cat(
+        [
+            indiv_map_states(a0, a, e, x[0])[None, :],
+            indiv_map_states(a0, a, e, x[1])[None, :],
+        ],
+        -2,
+    )
+    expected_filter = torch.cat(
+        [
+            indiv_filter(a0, a, e, x[0])[None, :, :],
+            indiv_filter(a0, a, e, x[1])[None, :, :],
+        ],
+        -3,
+    )
+    expected_smooth = torch.cat(
+        [
+            indiv_smooth(a0, a, e, x[0])[None, :, :],
+            indiv_smooth(a0, a, e, x[1])[None, :, :],
+        ],
+        -3,
+    )
+
+    assert torch.allclose(map_states, expected_map_states)
+    assert torch.allclose(filter, expected_filter)
+    assert torch.allclose(smooth, expected_smooth)
+
+    # Batch parameters.
+    a0 = torch.cat([a0[None, :], torch.tensor([0.2, 0.7, 0.1])[None, :]])
+    a = torch.cat(
+        [
+            a[None, :, :],
+            torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]])[
+                None, :, :
+            ],
+        ],
+        dim=0,
+    )
+    e = torch.cat(
+        [
+            e[None, :, :],
+            torch.tensor([[0.4, 0.6], [0.99, 0.01], [0.7, 0.3]])[None, :, :],
+        ],
+        dim=0,
+    )
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    map_states = hmm_distr.map_states(x[1])
+    filter = hmm_distr.filter(x[1])
+    smooth = hmm_distr.smooth(x[1])
+
+    expected_map_states = torch.cat(
+        [
+            indiv_map_states(a0[0], a[0], e[0], x[1])[None, :],
+            indiv_map_states(a0[1], a[1], e[1], x[1])[None, :],
+        ],
+        -2,
+    )
+    expected_filter = torch.cat(
+        [
+            indiv_filter(a0[0], a[0], e[0], x[1])[None, :, :],
+            indiv_filter(a0[1], a[1], e[1], x[1])[None, :, :],
+        ],
+        -3,
+    )
+    expected_smooth = torch.cat(
+        [
+            indiv_smooth(a0[0], a[0], e[0], x[1])[None, :, :],
+            indiv_smooth(a0[1], a[1], e[1], x[1])[None, :, :],
+        ],
+        -3,
+    )
+
+    assert torch.allclose(map_states, expected_map_states)
+    assert torch.allclose(filter, expected_filter)
+    assert torch.allclose(smooth, expected_smooth)
+
+    # Batch both parameters and values.
+    map_states = hmm_distr.map_states(x)
+    filter = hmm_distr.filter(x)
+    smooth = hmm_distr.smooth(x)
+
+    expected_map_states = torch.cat(
+        [
+            indiv_map_states(a0[0], a[0], e[0], x[0])[None, :],
+            indiv_map_states(a0[1], a[1], e[1], x[1])[None, :],
+        ],
+        -2,
+    )
+    expected_filter = torch.cat(
+        [
+            indiv_filter(a0[0], a[0], e[0], x[0])[None, :, :],
+            indiv_filter(a0[1], a[1], e[1], x[1])[None, :, :],
+        ],
+        -3,
+    )
+    expected_smooth = torch.cat(
+        [
+            indiv_smooth(a0[0], a[0], e[0], x[0])[None, :, :],
+            indiv_smooth(a0[1], a[1], e[1], x[1])[None, :, :],
+        ],
+        -3,
+    )
+
+    assert torch.allclose(map_states, expected_map_states)
+    assert torch.allclose(filter, expected_filter)
+    assert torch.allclose(smooth, expected_smooth)
+
+
+def test_sample_given_states():
+    a0 = torch.tensor([0.9, 0.08, 0.02])
+    a = torch.tensor([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]])
+    eps = 1e-10
+    # Effectively deterministic to check sampler.
+    e = torch.tensor([[1 - eps, eps], [eps, 1 - eps], [eps, 1 - eps]])
+
+    map_states = torch.tensor([0, 2, 1, 0], dtype=torch.long)
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    sample = hmm_distr.sample_given_states(map_states)
+    expected_sample = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]])
+    assert torch.allclose(sample, expected_sample)
+
+    # Batch values
+    map_states = torch.tensor([[0, 2, 1, 0], [0, 0, 0, 1]], dtype=torch.long)
+    sample = hmm_distr.sample_given_states(map_states)
+    expected_sample = torch.tensor(
+        [
+            [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],
+            [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
+        ]
+    )
+    assert torch.allclose(sample, expected_sample)
+
+    # Batch parameters
+    e = torch.cat(
+        [
+            e[None, :, :],
+            torch.tensor([[eps, 1 - eps], [eps, 1 - eps], [1 - eps, eps]])[None, :, :],
+        ],
+        dim=0,
+    )
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    sample = hmm_distr.sample_given_states(map_states[0])
+    expected_sample = torch.tensor(
+        [
+            [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],
+            [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
+        ]
+    )
+    assert torch.allclose(sample, expected_sample)
+
+    # Batch parameters and values.
+    sample = hmm_distr.sample_given_states(map_states)
+    expected_sample = torch.tensor(
+        [
+            [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],
+            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
+        ]
+    )
+    assert torch.allclose(sample, expected_sample)
+
+
+def test_sample_states():
+
+    # Effectively deterministic to check sampler.
+    eps = 1e-10
+    a0 = torch.tensor([1 - eps, eps / 2, eps / 2])
+    a = torch.tensor(
+        [
+            [eps / 2, 1 - eps, eps / 2],
+            [eps, 0.5 - eps / 2, 0.5 - eps / 2],
+            [eps, 0.5 - eps / 2, 0.5 - eps / 2],
+        ]
+    )
+    e = torch.tensor([[1 - eps, eps], [1 - eps, eps], [eps, 1 - eps]])
+    x = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
+
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    states = hmm_distr.sample_states(x)
+    expected_states = torch.tensor([0, 1, 2, 2])
+    assert torch.allclose(states, expected_states)
+
+    # Batch values.
+    x = torch.cat(
+        [
+            x[None, :, :],
+            torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 0.0]])[None, :, :],
+        ],
+        dim=0,
+    )
+    states = hmm_distr.sample_states(x)
+    expected_states = torch.tensor([[0, 1, 2, 2], [0, 1, 2, 1]])
+    assert torch.allclose(states, expected_states)
+
+    # Batch parameters
+    a0 = torch.cat([a0[None, :], torch.tensor([eps / 2, 1 - eps, eps / 2])[None, :]])
+    a = torch.cat(
+        [
+            a[None, :, :],
+            torch.tensor(
+                [
+                    [eps / 2, 1 - eps, eps / 2],
+                    [eps / 2, 1 - eps, eps / 2],
+                    [eps / 2, 1 - eps, eps / 2],
+                ]
+            )[None, :, :],
+        ],
+        dim=0,
+    )
+    e = torch.cat(
+        [
+            e[None, :, :],
+            torch.tensor([[1 - eps, eps], [0.5, 0.5], [eps, 1 - eps]])[None, :, :],
+        ],
+        dim=0,
+    )
+    hmm_distr = MissingDataDiscreteHMM(torch.log(a0), torch.log(a), torch.log(e))
+    states = hmm_distr.sample_states(x[1])
+    expected_states = torch.tensor([[0, 1, 2, 1], [1, 1, 1, 1]])
+    assert torch.allclose(states, expected_states)
+
+    # Batch both parameters and values.
+    states = hmm_distr.sample_states(x)
+    expected_states = torch.tensor([[0, 1, 2, 2], [1, 1, 1, 1]])
+    assert torch.allclose(states, expected_states)
diff --git a/tests/contrib/mue/test_statearrangers.py b/tests/contrib/mue/test_statearrangers.py
index c17bdcd50a..9df91e819e 100644
--- a/tests/contrib/mue/test_statearrangers.py
+++ b/tests/contrib/mue/test_statearrangers.py
@@ -81,9 +81,9 @@ def test_profile_alternate_imp(M, batch_size, substitute):
         c = c[None, :, :] * torch.ones([batch_size, 1, 1])
         if substitute:
             ll = ll.unsqueeze(0)
-    chk_a = torch.zeros((batch_dim_size, K, K))
-    chk_a0 = torch.zeros((batch_dim_size, K))
-    chk_e = torch.zeros((batch_dim_size, K, 4))
+    expected_a = torch.zeros((batch_dim_size, K, K))
+    expected_a0 = torch.zeros((batch_dim_size, K))
+    expected_e = torch.zeros((batch_dim_size, K, 4))
     for b in range(batch_dim_size):
         m, g = -1, 0
         u1[b][-1] = 1e-32
            for mp in range(M + gp):
                kp = mg2k(mp, gp, M)
                if m + 1 - g == mp and gp == 0:
-                    chk_a0[b, kp] = (1 - r1[b, m + 1 - g, g]) * (
+                    expected_a0[b, kp] = (1 - r1[b, m + 1 - g, g]) * (
                         1 - u1[b, m + 1 - g, g]
                     )
                elif m + 1 - g < mp and gp == 0:
-                    chk_a0[b, kp] = (
+                    expected_a0[b, kp] = (
                         (1 - r1[b, m + 1 - g, g])
                         * u1[b, m + 1 - g, g]
                         * simpleprod(
@@ -108,9 +108,9 @@ def test_profile_alternate_imp(M, batch_size, substitute):
                         * (1 - u1[b, mp, 2])
                     )
                elif m + 1 - g == mp and gp == 1:
-                    chk_a0[b, kp] = r1[b, m + 1 - g, g]
+                    expected_a0[b, kp] = r1[b, m + 1 - g, g]
                elif m + 1 - g < mp and gp == 1:
-                    chk_a0[b, kp] = (
+                    expected_a0[b, kp] = (
                         (1 - r1[b, m + 1 - g, g])
                         * u1[b, m + 1 - g, g]
                         * simpleprod(
@@ -128,11 +128,11 @@ def test_profile_alternate_imp(M, batch_size, substitute):
                    for mp in range(M + gp):
                        kp = mg2k(mp, gp, M)
                        if m + 1 - g == mp and gp == 0:
-                            chk_a[b, k, kp] = (1 - r1[b, m + 1 - g, g]) * (
+                            expected_a[b, k, kp] = (1 - r1[b, m + 1 - g, g]) * (
                                 1 - u1[b, m + 1 - g, g]
                             )
                        elif m + 1 - g < mp and gp == 0:
-                            chk_a[b, k, kp] = (
+                            expected_a[b, k, kp] = (
                                 (1 - r1[b, m + 1 - g, g])
                                 * u1[b, m + 1 - g, g]
                                 * simpleprod(
@@ -145,9 +145,9 @@ def test_profile_alternate_imp(M, batch_size, substitute):
                                 * (1 - u1[b, mp, 2])
                             )
                        elif m + 1 - g == mp and gp == 1:
-                            chk_a[b, k, kp] = r1[b, m + 1 - g, g]
+                            expected_a[b, k, kp] = r1[b, m + 1 - g, g]
                        elif m + 1 - g < mp and gp == 1:
-                            chk_a[b, k, kp] = (
+                            expected_a[b, k, kp] = (
                                 (1 - r1[b, m + 1 - g, g])
                                 * u1[b, m + 1 - g, g]
                                 * simpleprod(
@@ -159,23 +159,23 @@ def test_profile_alternate_imp(M, batch_size, substitute):
                                 * r1[b, mp, 2]
                             )
elif m == M and mp == M and g == 0 and gp == 0: - chk_a[b, k, kp] = 1.0 + expected_a[b, k, kp] = 1.0 for g in range(2): for m in range(M + g): k = mg2k(m, g, M) if g == 0: - chk_e[b, k, :] = s[b, m, :] + expected_e[b, k, :] = s[b, m, :] else: - chk_e[b, k, :] = c[b, m, :] + expected_e[b, k, :] = c[b, m, :] if substitute: - chk_e = torch.matmul(chk_e, ll) + expected_e = torch.matmul(expected_e, ll) # --- Check --- if batch_size is None: - chk_a = chk_a.squeeze() - chk_a0 = chk_a0.squeeze() - chk_e = chk_e.squeeze() + expected_a = expected_a.squeeze() + expected_a0 = expected_a0.squeeze() + expected_e = expected_e.squeeze() assert torch.allclose( torch.sum(torch.exp(a0ln)), torch.tensor(1.0), atol=1e-3, rtol=1e-3 @@ -186,9 +186,9 @@ def test_profile_alternate_imp(M, batch_size, substitute): atol=1e-3, rtol=1e-3, ) - assert torch.allclose(chk_a0, torch.exp(a0ln)) - assert torch.allclose(chk_a, torch.exp(aln)) - assert torch.allclose(chk_e, torch.exp(eln)) + assert torch.allclose(expected_a0, torch.exp(a0ln)) + assert torch.allclose(expected_a, torch.exp(aln)) + assert torch.allclose(expected_e, torch.exp(eln)) @pytest.mark.parametrize("batch_ancestor_seq", [False, True])