Clone this repository on your local machine by running:
git clone git@github.com:Bitbol-Lab/Iterative_masking.git
and move inside the root folder. You can then use the functions directly
from the cloned repository (in the folder Iterative_masking) or
install it as an editable package by running:
pip install -e .
We recommend creating and activating a dedicated conda or virtualenv
Python virtual environment.
To use the functions, the following Python packages are required:
- numpy
- scipy
- numba
- fastcore
- biopython
- esm==0.4.0
- pytorch
A GPU with CUDA support is also required.
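Before running the examples, it can help to confirm that the packages listed above are importable and that a CUDA device is visible. The following is a minimal sketch, not part of the repository: the helper names `check_environment` and `cuda_available` are hypothetical, and the import names assume that biopython is imported as `Bio` and pytorch as `torch`. It does not check the pinned esm version.

```python
import importlib.util

# Import names for the packages listed above
# (biopython is imported as "Bio", pytorch as "torch").
REQUIRED_MODULES = ["numpy", "scipy", "numba", "fastcore", "Bio", "esm", "torch"]


def check_environment(modules=REQUIRED_MODULES):
    """Return a dict mapping each module name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


def cuda_available():
    """Return True only if torch imports cleanly and reports a CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    try:
        import torch
    except Exception:
        # A broken torch install is treated the same as a missing one.
        return False
    return bool(torch.cuda.is_available())


status = check_environment()
for name, found in status.items():
    print(f"{name:10s} {'ok' if found else 'MISSING'}")
print("CUDA available:", cuda_available())
```

Any module reported as MISSING can be installed in the active environment before continuing.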
- IM_MSA_Transformer: class with different functions used to generate new MSAs with the iterative masking procedure.
- gen_MSAs: example function (with parser) that can be used to generate and save new sequences directly from the terminal.
filename = "PF00072.fasta"
filepath = "examples"
pmask = 0.1
iterations = 20
print('Tokenize')
IM_class = IM_MSA_Transformer(p_mask=pmask, filename=[filename], num=[-1], filepath=filepath)
tokenized_msa = IM_class.msa_batch_tokens
# Dictionary that maps amino acids to their token
idx_list = IM_class.idx_list
# Dictionary that maps tokens to their amino acid
aa_list = {v: k for k,v in idx_list.items()}
# Transform the tokenized MSA back into a string of amino acids
strings_msa = IM_class.untokenize_msa(tokenized_msa[:,:100,:])
import torch
# Mask all the tokens from 10 to 30 at each iteration (probability of 1) and keep the rest of the tokens unmasked
p_mask = torch.zeros(tokenized_msa.shape[-1])
p_mask[10:30] = 1.
IM_class.p_mask = p_mask
# Mask the tokens from 10 to 30 with a probability of 0.1 and keep the rest of the tokens unmasked
p_mask = torch.zeros(tokenized_msa.shape[-1])
p_mask[10:30] = 0.1
IM_class.p_mask = p_mask
# Mask all the tokens uniformly at random with a probability of 0.1
IM_class.p_mask = 0.1
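To make the effect of the three p_mask settings above concrete, here is a small pure-Python sketch. It assumes that each position is masked independently with the probability given for that position, and that a scalar p_mask applies the same probability everywhere. The function names `sample_mask` and `apply_mask` and the `<mask>` placeholder are hypothetical and do not come from the library.

```python
import random

MASK = "<mask>"  # hypothetical placeholder for a masked token


def sample_mask(p_mask, length, rng):
    """Draw one boolean per position; True means the position gets masked."""
    # A scalar probability is broadcast to every position.
    if isinstance(p_mask, (int, float)):
        p_mask = [p_mask] * length
    # rng.random() lies in [0, 1), so p = 1 always masks and p = 0 never does.
    return [rng.random() < p for p in p_mask]


def apply_mask(sequence, mask):
    """Replace the masked positions of a sequence with the placeholder."""
    return [MASK if m else tok for tok, m in zip(sequence, mask)]


rng = random.Random(0)
length = 40

# Mask positions 10 to 29 with probability 1, leave the rest untouched.
p_mask = [0.0] * length
p_mask[10:30] = [1.0] * 20

mask = sample_mask(p_mask, length, rng)
masked = apply_mask(list("A" * length), mask)
print("".join("#" if m else "." for m in mask))
```

With a probability of 0.1 on the same range, a different subset of positions 10 to 29 would be masked on each draw.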
- If use_pdf=True, generate tokens by sampling from the logits at temperature T.
- If save_all=True, then the first dimension of generated_tokens is the number of iterations.
- If rand_perm=True, then the sequence order is shuffled at every iteration (and shuffled back at the end).
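The use_pdf and T options above can be illustrated with a short sketch of temperature sampling. It assumes that use_pdf=False corresponds to picking the highest-scoring token (the text above only describes the use_pdf=True case), and the function names `softmax_with_temperature` and `pick_token` are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, T=1.0):
    """Convert logits to probabilities after dividing them by T."""
    scaled = [x / T for x in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pick_token(logits, use_pdf=False, T=1.0, rng=random):
    """Pick a token index, greedily or by sampling at temperature T."""
    if not use_pdf:
        # Assumed behaviour for use_pdf=False: take the argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax_with_temperature(logits, T)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
probs = softmax_with_temperature(logits, T=1.0)
probs_cold = softmax_with_temperature(logits, T=0.5)
greedy = pick_token(logits)
sampled = pick_token(logits, use_pdf=True, T=1.0, rng=random.Random(0))
print(greedy, sampled, probs, probs_cold)
```

Lower temperatures concentrate probability on the highest logit, while higher temperatures flatten the distribution.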
msa_tokens = tokenized_msa[:,:200]
# If use_pdf=True, generate tokens by sampling from the logits at temperature T
# If save_all=True, then the first dimension of generated_tokens is the number of iterations
# If rand_perm=True, then the sequence order is shuffled at every iteration (and shuffled back at the end)
generated_tokens = IM_class.generate_all_msa(msa_tokens, iterations, use_pdf=False, T=1, save_all=True, rand_perm=True)
generated_tokens = IM_class.print_tokens(generated_tokens)
print("Shape of the tokenized generated sequences: ", generated_tokens.shape)
- If use_pdf=True, generate tokens by sampling from the logits at temperature T.
- If save_all=True, then the first dimension of generated_tokens is the number of iterations.
- If rand_perm=True, then the sequence order is shuffled at every iteration (and shuffled back at the end).
- If use_rnd_ctx=False, then the context is all_context and is the same at every iteration.
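The rand_perm option in the list above shuffles the sequence order and restores it afterwards. A minimal sketch of that idea, using plain Python lists in place of token tensors, could look as follows. The helper names `shuffle_rows` and `unshuffle_rows` are hypothetical.

```python
import random


def shuffle_rows(rows, rng):
    """Return the rows in random order together with the permutation used."""
    perm = list(range(len(rows)))
    rng.shuffle(perm)
    # Row at new position j is the original row perm[j].
    shuffled = [rows[i] for i in perm]
    return shuffled, perm


def unshuffle_rows(shuffled, perm):
    """Undo shuffle_rows by sending each row back to its original index."""
    restored = [None] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        restored[old_pos] = shuffled[new_pos]
    return restored


rng = random.Random(0)
msa = ["seq_a", "seq_b", "seq_c", "seq_d", "seq_e"]
shuffled, perm = shuffle_rows(msa, rng)
restored = unshuffle_rows(shuffled, perm)
print(shuffled, restored)
```

Keeping the permutation is what allows the generated sequences to be returned in the same order as the input.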
ancestor = tokenized_msa[:,:10]
all_context = tokenized_msa[:,10:210]
generated_tokens = IM_class.generate_with_context_msa(ancestor, iterations, use_pdf=False, T=1, all_context=all_context,
                                                      use_rnd_ctx=False, save_all=True, rand_perm=True)
generated_tokens = IM_class.print_tokens(generated_tokens)
print("Shape of the tokenized generated sequences: ", generated_tokens.shape)
- If use_pdf=True, generate tokens by sampling from the logits at temperature T.
- If save_all=True, then the first dimension of generated_tokens is the number of iterations.
- If rand_perm=True, then the sequence order is shuffled at every iteration (and shuffled back at the end).
- If use_rnd_ctx=True, then the context is a different sub-MSA sampled at each iteration from the full MSA (the entire MSA is given as the first entry of all_context, the depth of the sub-MSA is given by the second entry of all_context).
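The use_rnd_ctx=True behaviour described above, where a new sub-MSA is drawn at each iteration, can be sketched as below. This assumes that the context rows are drawn uniformly without replacement, which is not stated in the text. The function name `sample_sub_msa` is hypothetical.

```python
import random


def sample_sub_msa(full_msa, depth, rng):
    """Draw `depth` distinct rows from the full MSA (assumed: uniform)."""
    if depth > len(full_msa):
        raise ValueError("depth cannot exceed the number of sequences")
    indices = rng.sample(range(len(full_msa)), depth)
    return [full_msa[i] for i in indices]


rng = random.Random(0)
full_msa = [f"seq_{i}" for i in range(50)]  # stand-in for tokenized rows
depth = 5
iterations = 3

contexts = []
for it in range(iterations):
    # A fresh context is drawn at every iteration.
    context = sample_sub_msa(full_msa, depth, rng)
    contexts.append(context)
    print(it, context)
```

In the library call, the pair (full MSA, depth) passed as all_context plays the role of `full_msa` and `depth` here.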
ancestor = tokenized_msa[:,:10]
all_context = (tokenized_msa[:,10:], 200)
generated_tokens = IM_class.generate_with_context_msa(ancestor, iterations, use_pdf=False, T=1, all_context=all_context,
                                                      use_rnd_ctx=True, use_two_msas=False, mode="same", save_all=True, rand_perm=True)
generated_tokens = IM_class.print_tokens(generated_tokens)
# If save_all=True, then the first dimension of generated_tokens is the number of iterations
print("Shape of the tokenized generated sequences: ", generated_tokens.shape)
If you want to sample sequences from two different MSAs separately, you
can use the following parameters:
- use_rnd_ctx=True, same as before.
- use_two_msas=True, if you want to sample from two different MSAs given as a tuple in the first entry of all_context (the second entry of all_context is the depth of each sub-MSA).
- mode is the sampling mode. If mode="same", the same number of sequences is sampled from each MSA. If mode="ratio", the number of sequences sampled from each MSA is proportional to the current iteration: sampling starts with all sequences from the first MSA and none from the second, and ends with none from the first MSA and all sequences from the second.
- warm_up, used only if mode="ratio", is the number of iterations before starting to sample from the second MSA, while cool_down is the number of iterations (before the end) when the sampling from the first MSA is stopped.
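To give an intuition for mode="ratio" together with warm_up and cool_down, the sketch below computes one plausible linear schedule for how many context sequences come from each MSA at a given iteration. The exact formula used by the library is not given in the text, so the linear ramp, the rounding, and the function name `ratio_schedule` are all assumptions made for illustration.

```python
def ratio_schedule(iteration, iterations, depth, warm_up=0, cool_down=0):
    """Return (n_first, n_second) context sequences for one iteration.

    Assumed schedule: only the first MSA is used up to and including
    iteration `warm_up`, only the second MSA is used from iteration
    `iterations - cool_down - 1` onwards, with a linear ramp in between.
    """
    start = warm_up
    end = iterations - cool_down - 1
    if iteration <= start:
        frac_second = 0.0
    elif iteration >= end:
        frac_second = 1.0
    else:
        frac_second = (iteration - start) / (end - start)
    n_second = round(frac_second * depth)
    return depth - n_second, n_second


iterations, depth = 20, 200
schedule = [
    ratio_schedule(i, iterations, depth, warm_up=2, cool_down=3)
    for i in range(iterations)
]
for i, (n_first, n_second) in enumerate(schedule):
    print(f"iteration {i:2d}: {n_first:3d} from first MSA, {n_second:3d} from second MSA")
```

In this sketch the two counts always add up to `depth`. Whether the library keeps the total fixed in the same way is also an assumption.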
ancestor = tokenized_msa[:,:10]
all_context = ((tokenized_msa[:,10:1000], tokenized_msa[:,1000:]), 200)
generated_tokens = IM_class.generate_with_context_msa(ancestor, iterations, use_pdf=False, T=1, all_context=all_context,
                                                      use_rnd_ctx=True, use_two_msas=True, mode="same", warm_up=0, cool_down=0, save_all=True, rand_perm=True)
generated_tokens = IM_class.print_tokens(generated_tokens)
# If save_all=True, then the first dimension of generated_tokens is the number of iterations
print("Shape of the tokenized generated sequences: ", generated_tokens.shape)
gen_MSAs(filepath="examples",
         filename=["PF00072.fasta"],
         new_dir="results",
         pdf=False,
         T=1,
         sample_all=False,
         Iters=200,
         pmask=0.1,
         num=[600],
         depth=1,
         generate=False,
         print_all=False,
         range_vals=False,
         phylo_w=False)