Sampling functions for the MuE/missing data discrete HMM. #2898

Merged · 25 commits · Jul 22, 2021

Commits
3b5f55a
Basic sampling function for missingdata hmm
EWeinstein Mar 24, 2021
8a27ef5
Merge remote-tracking branch 'upstream/dev' into mue-sample
EWeinstein Jun 23, 2021
fbd4c23
Viterbi algorithm for obtaining MAP state path estimator.
EWeinstein Jun 24, 2021
96ea395
Sample conditional on latent state
EWeinstein Jun 24, 2021
8ed6bef
Filtering, smoothing, forward-backward sampling.
EWeinstein Jun 25, 2021
c4ee112
Tool for writing one-hot encoded sequence samples to fasta files.
EWeinstein Jun 25, 2021
d39314d
Updated reference.
EWeinstein Jun 25, 2021
421aad5
Fix broadcasting in sample.
EWeinstein Jun 29, 2021
bd5828b
Save additional information.
EWeinstein Jun 29, 2021
7a18180
Adjust device for new torch.randperm cuda behavior.
EWeinstein Jul 3, 2021
0320591
Fix keyword input.
EWeinstein Jul 3, 2021
a16725d
Another fix for the torch.randperm update
EWeinstein Jul 3, 2021
bb4627b
Fix device call in dataloader.
EWeinstein Jul 3, 2021
2fac7f2
Adjust device passing for randperm.
EWeinstein Jul 3, 2021
501192f
Move to cpu to write.
EWeinstein Jul 4, 2021
467abd6
Break out conditional distribution calculation.
EWeinstein Jul 5, 2021
e2d73f7
Merge remote-tracking branch 'upstream/dev' into mue-sample
EWeinstein Jul 7, 2021
5155f3b
Formatting.
EWeinstein Jul 7, 2021
ce10e15
Rename get states functions and test variables.
EWeinstein Jul 22, 2021
2a82730
Debug --cuda --cpu-data error in examples.
EWeinstein Jul 22, 2021
75e5948
Still debugging generators.
EWeinstein Jul 22, 2021
8ffcd80
Debug generator + cuda.
EWeinstein Jul 22, 2021
964550e
Debug generator + cuda
EWeinstein Jul 22, 2021
b61e3b0
Adjust generator and cuda handling in both example models.
EWeinstein Jul 22, 2021
bcbb3a3
Reformat.
EWeinstein Jul 22, 2021
Changes from all commits
24 changes: 17 additions & 7 deletions examples/contrib/mue/FactorMuE.py
@@ -23,8 +23,8 @@

 Reference:
 [1] E. N. Weinstein, D. S. Marks (2021)
-"Generative probabilistic biological sequence models that account for
-mutational variability"
+"A structured observation distribution for generative biological sequence
+prediction and forecasting"
 https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
 """

@@ -62,10 +62,10 @@ def generate_data(small_test, include_stop, device):
 def main(args):

     # Load dataset.
-    if args.cpu_data and args.cuda:
+    if args.cpu_data or not args.cuda:
         device = torch.device("cpu")
     else:
-        device = None
+        device = torch.device("cuda")
     if args.test:
         dataset = generate_data(args.small, args.include_stop, device)
     else:
@@ -84,7 +84,7 @@ def main(args):
     # Specific data split seed, for comparability across models and
     # parameter initializations.
     pyro.set_rng_seed(args.rng_data_seed)
-    indices = torch.randperm(sum(data_lengths)).tolist()
+    indices = torch.randperm(sum(data_lengths), device=device).tolist()
     dataset_train, dataset_test = [
         torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
         for offset, length in zip(
@@ -131,7 +131,12 @@ def main(args):
     )
     n_epochs = args.n_epochs
     losses = model.fit_svi(
-        dataset_train, n_epochs, args.anneal, args.batch_size, scheduler, args.jit
+        dataset_train,
+        n_epochs,
+        args.anneal,
+        args.batch_size,
+        scheduler,
+        args.jit,
     )

     # Evaluate.
@@ -233,13 +238,18 @@ def main(args):
     )
     with open(
         os.path.join(
-            args.out_folder, "FactorMuE_results.input_{}.txt".format(time_stamp)
+            args.out_folder,
+            "FactorMuE_results.input_{}.txt".format(time_stamp),
         ),
         "w",
     ) as ow:
         ow.write("[args]\n")
+        args.latent_seq_length = model.latent_seq_length
+        args.latent_alphabet = model.latent_alphabet_length
         for elem in list(args.__dict__.keys()):
             ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
+        ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
+        ow.write("max_length = {}\n".format(dataset.max_length))


 if __name__ == "__main__":
13 changes: 8 additions & 5 deletions examples/contrib/mue/ProfileHMM.py
@@ -27,8 +27,8 @@
 Cambridge university press

 [2] E. N. Weinstein, D. S. Marks (2021)
-"Generative probabilistic biological sequence models that account for
-mutational variability"
+"A structured observation distribution for generative biological sequence
+prediction and forecasting"
 https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
 """

@@ -68,10 +68,10 @@ def main(args):
     pyro.set_rng_seed(args.rng_seed)

     # Load dataset.
-    if args.cpu_data and args.cuda:
+    if args.cpu_data or not args.cuda:
         device = torch.device("cpu")
     else:
-        device = None
+        device = torch.device("cuda")
     if args.test:
         dataset = generate_data(args.small, args.include_stop, device)
     else:
@@ -90,7 +90,7 @@ def main(args):
     # Specific data split seed, for comparability across models and
     # parameter initializations.
     pyro.set_rng_seed(args.rng_data_seed)
-    indices = torch.randperm(sum(data_lengths)).tolist()
+    indices = torch.randperm(sum(data_lengths), device=device).tolist()
     dataset_train, dataset_test = [
         torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
         for offset, length in zip(
@@ -200,8 +200,11 @@ def main(args):
         "w",
     ) as ow:
         ow.write("[args]\n")
+        args.latent_seq_length = model.latent_seq_length
         for elem in list(args.__dict__.keys()):
             ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
+        ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
+        ow.write("max_length = {}\n".format(dataset.max_length))


 if __name__ == "__main__":
57 changes: 57 additions & 0 deletions pyro/contrib/mue/dataloaders.py
@@ -141,3 +141,60 @@ def __len__(self):
     def __getitem__(self, ind):

         return (self.seq_data[ind], self.L_data[ind])
+
+
+def write(x, alphabet, file, truncate_stop=False, append=False, scores=None):
+    """
+    Write sequence samples to a file in fasta format.
+
+    :param ~torch.Tensor x: One-hot encoded sequences, with size
+        ``(data_size, seq_length, alphabet_length)``. May be padded with
+        zeros for variable length sequences.
+    :param ~np.array alphabet: Alphabet.
+    :param str file: Output file, where sequences will be written
+        in fasta format.
+    :param bool truncate_stop: If True, sequences are truncated at the
+        first stop symbol (i.e. the stop symbol and everything after it
+        is not written). If False, the whole sequence is written,
+        including any internal stop symbols.
+    :param bool append: If True, sequences are appended to the end of the
+        output file. If False, the file is first erased.
+    :param scores: Optional per-sequence values; if provided, each fasta
+        header is the corresponding value instead of the sequence index.
+    """
+    print_alphabet = np.array(list(alphabet) + [""])
+    x = torch.cat([x, torch.zeros(list(x.shape[:2]) + [1])], -1)
+    if truncate_stop:
+        # Mask out the first stop symbol and everything after it.
+        mask = (
+            torch.cumsum(
+                torch.matmul(
+                    x, torch.tensor(print_alphabet == "*", dtype=torch.double)
+                ),
+                -1,
+            )
+            > 0
+        ).to(torch.double)
+        x = x * (1 - mask).unsqueeze(-1)
+        x[:, :, -1] = mask
+    else:
+        # Map zero-padded positions to the empty symbol.
+        x[:, :, -1] = (torch.sum(x, -1) < 0.5).to(torch.double)
+    index = (
+        torch.matmul(x, torch.arange(x.shape[-1], dtype=torch.double))
+        .to(torch.long)
+        .cpu()
+        .numpy()
+    )
+    if scores is None:
+        seqs = [
+            ">{}\n".format(j) + "".join(elem) + "\n"
+            for j, elem in enumerate(print_alphabet[index])
+        ]
+    else:
+        seqs = [
+            ">{}\n".format(j) + "".join(elem) + "\n"
+            for j, elem in zip(scores, print_alphabet[index])
+        ]
+    if append:
+        open_flag = "a"
+    else:
+        open_flag = "w"
+    with open(file, open_flag) as fw:
+        fw.write("".join(seqs))
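A minimal usage sketch (not part of the diff; the alphabet, tensors, and file name are made up for illustration): writing two one-hot encoded sequences over the alphabet A, B with stop symbol * to a fasta file.

    import numpy as np
    import torch

    from pyro.contrib.mue.dataloaders import write

    alphabet = np.array(["A", "B", "*"])
    # Two length-4 one-hot sequences, in double precision since write()
    # mixes the input with double-typed tensors internally.
    x = torch.tensor(
        [
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]],  # A B * B
            [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 0]],  # B A * (pad)
        ],
        dtype=torch.double,
    )
    # Truncation drops the first stop symbol and everything after it,
    # so this writes ">0\nAB\n>1\nBA\n" to samples.fasta.
    write(x, alphabet, "samples.fasta", truncate_stop=True)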
205 changes: 205 additions & 0 deletions pyro/contrib/mue/missingdatahmm.py
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import Categorical, OneHotCategorical

from pyro.distributions import constraints
from pyro.distributions.hmm import _sequential_logmatmulexp
Expand Down Expand Up @@ -110,3 +111,207 @@ def log_prob(self, value):
         # Marginalize out final state.
         result = result.logsumexp(-1)
         return result
+
+    def sample(self, sample_shape=torch.Size([])):
+        """
+        :param ~torch.Size sample_shape: Sample shape; the last dimension
+            must be ``num_steps``, and the shape must be broadcastable to
+            ``(batch_size, num_steps)``, where ``batch_size`` is an int,
+            not a tuple.
+        """
+        # shape: batch_size x num_steps x categorical_size
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(sample_shape) + [1]),
+            torch.Size((1, 1, self.event_shape[-1])),
+        )
+        # state: batch_size x state_dim
+        state = OneHotCategorical(logits=self.initial_logits).sample()
+        # sample: batch_size x num_steps x categorical_size
+        sample = torch.zeros(shape)
+        for i in range(shape[-2]):
+            # batch_size x 1 x state_dim @
+            # batch_size x state_dim x categorical_size
+            obs_logits = torch.matmul(
+                state.unsqueeze(-2), self.observation_logits
+            ).squeeze(-2)
+            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
+            # batch_size x 1 x state_dim @
+            # batch_size x state_dim x state_dim
+            trans_logits = torch.matmul(
+                state.unsqueeze(-2), self.transition_logits
+            ).squeeze(-2)
+            state = OneHotCategorical(logits=trans_logits).sample()
+
+        return sample
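A minimal sampling sketch (the parameter values are made up; this assumes the constructor takes ``(initial_logits, transition_logits, observation_logits)``, as in this module):

    import torch

    from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM

    # 2 hidden states, 3 observation categories.
    hmm = MissingDataDiscreteHMM(
        torch.tensor([0.9, 0.1]).log(),                          # initial
        torch.tensor([[0.8, 0.2], [0.1, 0.9]]).log(),            # transition
        torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]).log(),  # observation
    )
    # One sequence of 10 steps: one-hot output with shape
    # (batch_size, num_steps, categorical_size) = (1, 10, 3).
    x = hmm.sample(torch.Size([1, 10]))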

+    def filter(self, value):
+        """
+        Compute the marginal log probability of the state variable at
+        each step, conditional on the observations up to and including
+        that step.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical
+            output.
+        """
+        # batch_size x num_steps x state_dim
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(value.shape[:-1]) + [1]),
+            torch.Size((1, 1, self.initial_logits.shape[-1])),
+        )
+        filter = torch.zeros(shape)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+
+        # Forward pass. (This could be parallelized using the
+        # Sarkka & Garcia-Fernandez method.)
+        filter[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
+        filter[..., 0, :] = filter[..., 0, :] - torch.logsumexp(
+            filter[..., 0, :], -1, True
+        )
+        for i in range(1, shape[-2]):
+            filter[..., i, :] = torch.logsumexp(
+                filter[..., i - 1, :, None] + result[..., i - 1, :, :], -2
+            )
+            filter[..., i, :] = filter[..., i, :] - torch.logsumexp(
+                filter[..., i, :], -1, True
+            )
+        return filter
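For reference, the forward pass above is the standard normalized HMM filtering recursion. Writing $\pi$ for the initial distribution, $T_{ij}$ for the transition probability, and $o_t(j) = p(x_t \mid z_t = j)$ for the observation factor (the `value_logits` term), it computes, in log space,

$$\alpha_0(j) \propto \pi_j \, o_0(j), \qquad \alpha_t(j) \propto o_t(j) \sum_i \alpha_{t-1}(i) \, T_{ij},$$

where $\alpha_t(j) = p(z_t = j \mid x_{0:t})$ and each step is renormalized to sum to one.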

+    def smooth(self, value):
+        """
+        Compute the posterior log marginal distribution of the state
+        variable at each step, conditional on the full observation
+        sequence (smoothing).
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical
+            output.
+        """
+        # Compute filter and initialize.
+        filter = self.filter(value)
+        shape = filter.shape
+        backfilter = torch.zeros(shape)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+        # Construct backwards filter.
+        for i in range(shape[-2] - 1, 0, -1):
+            backfilter[..., i - 1, :] = torch.logsumexp(
+                backfilter[..., i, None, :] + result[..., i - 1, :, :], -1
+            )
+
+        # Compute smoothed version.
+        smooth = filter + backfilter
+        smooth = smooth - torch.logsumexp(smooth, -1, True)
+        return smooth
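The backward pass above computes the usual $\beta$ messages for a sequence of $N$ steps, with $\beta_{N-1}(i) = 1$ and

$$\beta_{t-1}(i) = \sum_j T_{ij} \, o_t(j) \, \beta_t(j),$$

so the returned smoothed marginals are $p(z_t = i \mid x_{0:N-1}) \propto \alpha_t(i) \, \beta_t(i)$.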

+    def sample_states(self, value):
+        """
+        Sample the posterior state path with the forward filtering-
+        backward sampling algorithm.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical
+            output.
+        """
+        filter = self.filter(value)
+        shape = filter.shape
+        joint = filter.unsqueeze(-1) + self.transition_logits.unsqueeze(-3)
+        states = torch.zeros(shape[:-1], dtype=torch.long)
+        states[..., -1] = Categorical(logits=filter[..., -1, :]).sample()
+        for i in range(shape[-2] - 1, 0, -1):
+            logits = torch.gather(
+                joint[..., i - 1, :, :],
+                -1,
+                states[..., i, None, None]
+                * torch.ones([shape[-1], 1], dtype=torch.long),
+            ).squeeze(-1)
+            states[..., i - 1] = Categorical(logits=logits).sample()
+        return states
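This is exact forward filtering-backward sampling: the last state is drawn from the final filtering distribution, and each earlier state is drawn conditional on its sampled successor,

$$z_{N-1} \sim \alpha_{N-1}, \qquad p(z_{t-1} = i \mid z_t = j, x_{0:N-1}) \propto \alpha_{t-1}(i) \, T_{ij},$$

which yields a joint sample from the posterior $p(z_{0:N-1} \mid x_{0:N-1})$.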

+    def map_states(self, value):
+        """
+        Compute the maximum a posteriori (MAP) estimate of the state
+        variable with the Viterbi algorithm.
+
+        :param ~torch.Tensor value: One-hot encoded observation.
+            Must be real-valued (float) and broadcastable to
+            ``(batch_size, num_steps, categorical_size)`` where
+            ``categorical_size`` is the dimension of the categorical
+            output.
+        """
+        # Setup for Viterbi.
+        # batch_size x num_steps x state_dim
+        shape = broadcast_shape(
+            torch.Size(list(self.batch_shape) + [1, 1]),
+            torch.Size(list(value.shape[:-1]) + [1]),
+            torch.Size((1, 1, self.initial_logits.shape[-1])),
+        )
+        state_logits = torch.zeros(shape)
+        state_traceback = torch.zeros(shape, dtype=torch.long)
+
+        # Combine observation and transition factors.
+        # batch_size x num_steps x state_dim
+        value_logits = torch.matmul(
+            value, torch.transpose(self.observation_logits, -2, -1)
+        )
+        # batch_size x num_steps-1 x state_dim x state_dim
+        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
+
+        # Forward pass.
+        state_logits[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
+        for i in range(1, shape[-2]):
+            transit_weights = (
+                state_logits[..., i - 1, :, None] + result[..., i - 1, :, :]
+            )
+            state_logits[..., i, :], state_traceback[..., i, :] = torch.max(
+                transit_weights, -2
+            )
+        # Traceback.
+        map_states = torch.zeros(shape[:-1], dtype=torch.long)
+        map_states[..., -1] = torch.argmax(state_logits[..., -1, :], -1)
+        for i in range(shape[-2] - 1, 0, -1):
+            map_states[..., i - 1] = torch.gather(
+                state_traceback[..., i, :], -1, map_states[..., i].unsqueeze(-1)
+            ).squeeze(-1)
+        return map_states
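The forward pass and traceback above implement the Viterbi recursion

$$\delta_0(j) = \log \pi_j + \log o_0(j), \qquad \delta_t(j) = \max_i \big[ \delta_{t-1}(i) + \log T_{ij} \big] + \log o_t(j),$$

with the maximizing indices $\psi_t(j)$ stored in `state_traceback`; the MAP path is recovered backwards via $z^*_{N-1} = \arg\max_j \delta_{N-1}(j)$ and $z^*_{t-1} = \psi_t(z^*_t)$.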

+    def given_states(self, states):
+        """
+        Distribution over observations, conditional on the state variable.
+
+        :param ~torch.Tensor states: State trajectory. Must be
+            integer-valued (long) and broadcastable to
+            ``(batch_size, num_steps)``.
+        """
+        shape = broadcast_shape(
+            list(self.batch_shape) + [1, 1],
+            list(states.shape[:-1]) + [1, 1],
+            [1, 1, self.observation_logits.shape[-1]],
+        )
+        states_index = states.unsqueeze(-1) * torch.ones(shape, dtype=torch.long)
+        obs_logits = self.observation_logits * torch.ones(shape)
+        logits = torch.gather(obs_logits, -2, states_index)
+        return OneHotCategorical(logits=logits)

+    def sample_given_states(self, states):
+        """
+        Sample an observation conditional on the state variable.
+
+        :param ~torch.Tensor states: State trajectory. Must be
+            integer-valued (long) and broadcastable to
+            ``(batch_size, num_steps)``.
+        """
+        conditional = self.given_states(states)
+        return conditional.sample()
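Continuing the sampling sketch above (same assumed ``hmm`` and ``x``), the new posterior utilities all consume one-hot observations of shape ``(batch_size, num_steps, categorical_size)``:

    log_filter = hmm.filter(x)   # log p(z_t | x_{0:t}), shape (1, 10, 2)
    log_smooth = hmm.smooth(x)   # log p(z_t | x_{0:N-1}), shape (1, 10, 2)
    z_map = hmm.map_states(x)    # Viterbi MAP state path, shape (1, 10), long
    z = hmm.sample_states(x)     # joint posterior sample of the state path
    x_new = hmm.sample_given_states(z)  # fresh observations given that path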