Merge pull request #5 from RosettaCommons/document-code
Document code
jeffreyruffolo authored Sep 27, 2021
2 parents 1366538 + a9735a0 commit 7f72774
Showing 28 changed files with 332 additions and 578 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -28,4 +28,4 @@ scripts/
temp/
.DS_Store
venv/
pred_*/
pred*/
8 changes: 6 additions & 2 deletions annotate_attention.py
@@ -66,6 +66,7 @@ def _get_args():
type=str,
default="CA",
help="Output branch to use attention from.")
parser.add_argument('--use_gpu', default=False, action="store_true")

return parser.parse_args()

@@ -81,9 +82,12 @@ def _cli():
attention_branch = args.attention_branch.lower()

device_type = 'cuda' if torch.cuda.is_available(
) and args.try_gpu else 'cpu'
) and args.use_gpu else 'cpu'
device = torch.device(device_type)

if not os.path.exists(model_file):
exit("No model file found at: {}".format(model_file))

model = load_model(model_file, eval_mode=True, device=device)

if not cdr_loop in cdr_names:
@@ -106,7 +110,7 @@ def _cli():
f.write(fasta_content)

cdr_i = cdr_indices(pdb_file, cdr_loop)
annotate_structure(model, temp_fasta, pdb_file, cdr_i, attention_branch)
annotate_structure(model, temp_fasta, out_file, cdr_i, attention_branch)


if __name__ == '__main__':
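The change above gates GPU use on the new --use_gpu flag and exits early when the model file is missing. A minimal standalone sketch of that argument-and-device pattern (the paths and defaults below are illustrative, not taken from the repository):

import argparse
import os

import torch


def _get_args():
    # Mirrors the flag added in this commit: off by default, enabled by presence.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", type=str, default="model.p")
    parser.add_argument("--use_gpu", default=False, action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = _get_args()

    # Use CUDA only when it is both available and explicitly requested.
    device_type = "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    device = torch.device(device_type)

    # Fail fast with a clear message if the weights are missing.
    if not os.path.exists(args.model_file):
        exit("No model file found at: {}".format(args.model_file))

    print("Running on {}".format(device))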
103 changes: 45 additions & 58 deletions data/TestSetList.txt
@@ -1,106 +1,93 @@
PDB_ID
1bey
1cz8
1dlf
1fns
1gig
1jfq
1jpt
1mfa
1mim
1mlb
1mqk
1nlb
1oaq
1seq
1x9q
1sy6
1yy8
2adf
2d7t
2e27
2fb4
2fbj
2hwz
2r8s
2v17
2vxv
2w60
2xwt
2ypv
3e8u
3eo0
3eo9
3g5y
3giz
3gkw
3gnm
3go1
3hc4
3hnt
3i9g
3ifl
3liz
3lmj
3m8o
3mlr
3mxw
3nfs
3nps
3o2d
3oz9
3p0y
3pp3
3qwo
3t65
3u0t
3umt
3v0w
4cni
4dn3
4f57
4g5z
4g6k
4h0h
4h20
4hkz
4hpy
4i77
4irz
4kaq
4m6n
4nyl
4nzu
4O02
4NYL
1BEY
4QXG
5XXY
4OJF
1MIM
5Y9K
5NHW
5N2K
4G5Z
4DN3
5WUV
1YY8
5KMV
3NFS
4OD2
5I5K
3EO9
5VKK
3EO0
5CSZ
4G6K
4M6N
3O2D
5VH3
5TRU
4D9Q
4I77
3C08
3QWO
1SY6
4IRZ
3B2U
3GKW
5GGQ
3PP3
3GIZ
4CNI
4X7S
4K3J
2HWZ
5SX4
5DK3
5JXE
6AND
3U0T
3S34
1CZ8
4KAQ
4YPG
4EDW
5L6Y
4HKZ
5GGU
3HMW
4od2
4ojf
4qxg
4x7s
4ypg
5csz
5dk3
5ggq
5ggu
5i5k
5jxe
5kmv
5l6y
5n2k
5nhw
5sx4
5tru
5vh3
5wuv
5xxy
5y9k
6and
34 changes: 30 additions & 4 deletions deepab/analysis/attention_analysis.py
@@ -1,9 +1,17 @@
from typing import Dict, Tuple
import torch

from deepab.models.AbResNet import AbResNet
from deepab.util.model_out import get_inputs_from_fasta


def get_HW_attn_for_model_input(model, model_input):
def get_HW_attn_for_model_input(
model: AbResNet,
model_input: torch.LongTensor,
):
"""
Gets row- and column-wise attention matrices for one-hot encoded input
"""
with torch.no_grad():
hw_attn = model.forward_attn(model_input)

@@ -13,14 +21,26 @@ def get_HW_attn_for_model_input(model, model_input):
return hw_attn


def get_HW_attn_for_fasta(model, fasta_file):
def get_HW_attn_for_fasta(
model: AbResNet,
fasta_file: str,
):
"""
Gets row- and column-wise attention matrices for fasta file
"""
model_input = get_inputs_from_fasta(fasta_file)
hw_attn = get_HW_attn_for_model_input(model, model_input)

return hw_attn


def get_mean_range_attn(attn, r):
def get_mean_range_attn(
attn: torch.FloatTensor,
r: Tuple[int],
):
"""
Calculates average attention on other whole sequence for given residue range
"""
att_H, att_W = attn
range_att_H = att_H[r[0]:r[1], :, r[0]:r[1]].mean(0)
range_att_W = att_W[r[0]:r[1], r[0]:r[1]].mean(0)
@@ -34,7 +54,13 @@ def get_mean_range_attn(attn, r):
return range_seq_attn, attn_mat


def get_cdr_attn_dict(attn, cdr_range_dict):
def get_cdr_attn_dict(
attn: torch.FloatTensor,
cdr_range_dict: Dict[str, Tuple[int, int]],
):
"""
Calculates attention on whole sequence for each CDR loop
"""
cdr_attn_dict = {
cdr: get_mean_range_attn(attn, r)
for cdr, r in cdr_range_dict.items()
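With the type hints and docstrings added here, the intended call chain is: compute row- and column-wise attention for a heavy/light chain sequence, then average it over each CDR loop. A hedged usage sketch follows; the load_model import path, file paths, and CDR index ranges are assumptions for illustration, not values from this commit.

import torch

from deepab.analysis.attention_analysis import (get_HW_attn_for_fasta,
                                                get_cdr_attn_dict)
from deepab.models.AbResNet import load_model  # assumed import path

model_file = "trained_models/abresnet.p"  # hypothetical path
fasta_file = "data/my_antibody.fasta"     # hypothetical path

model = load_model(model_file, eval_mode=True, device=torch.device("cpu"))

# Row- and column-wise attention matrices for the full sequence.
hw_attn = get_HW_attn_for_fasta(model, fasta_file)

# (start, end) residue index ranges for each CDR loop; illustrative values only.
cdr_range_dict = {"h1": (25, 32), "h2": (51, 57), "h3": (98, 109)}

# Mean attention from each loop onto the whole sequence.
cdr_attn_dict = get_cdr_attn_dict(hw_attn, cdr_range_dict)
for cdr, (seq_attn, attn_mat) in cdr_attn_dict.items():
    print(cdr, seq_attn.shape)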
63 changes: 39 additions & 24 deletions deepab/analysis/design_metrics.py
@@ -9,9 +9,11 @@
from deepab.util.util import load_full_seq, one_hot_seq, lev_distance, _aa_dict, _rev_aa_dict


def generate_pssm(model: AbResNet,
fasta_file: str,
teacher_forcing_ratio: float = 1.):
def generate_pssm(
model: AbResNet,
fasta_file: str,
teacher_forcing_ratio: float = 1.,
):
wt_seq = get_inputs_from_fasta(fasta_file)
pssm = model.get_lstm_pssm(wt_seq,
teacher_forcing_ratio=teacher_forcing_ratio)
@@ -20,8 +22,11 @@ def generate_pssm(model: AbResNet,
return pssm


def get_cce_for_inputs(model: Union[ModelEnsemble, AbResNet],
inputs: torch.Tensor):
def get_cce_for_inputs(
model: Union[ModelEnsemble, AbResNet],
inputs: torch.Tensor,
):

with torch.no_grad():
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=-999)
mut_logits = model(inputs)
@@ -40,9 +45,11 @@ def get_cce_for_inputs(model: Union[ModelEnsemble, AbResNet],
return cce


def get_fasta_cce(model: Union[ModelEnsemble, AbResNet],
fasta_file: str,
device: str = None):
def get_fasta_cce(
model: Union[ModelEnsemble, AbResNet],
fasta_file: str,
device: str = None,
):
inputs = get_inputs_from_fasta(fasta_file)
if type(device) != type(None):
model = model.to(device)
@@ -53,24 +60,30 @@ def get_fasta_cce(model: Union[ModelEnsemble, AbResNet],
return cce


def get_dcce(model: Union[ModelEnsemble, AbResNet], des_fasta: str,
wt_fasta: str, device: str):
def get_dcce(
model: Union[ModelEnsemble, AbResNet],
des_fasta: str,
wt_fasta: str,
device: str,
):
des_cce = get_fasta_cce(model, des_fasta, device)
wt_cce = get_fasta_cce(model, wt_fasta, device)

# pssm = generate_pssm(model.models[0], wt_fasta, teacher_forcing_ratio=1)
# nl_pssm = -np.log(pssm)
pssm = generate_pssm(model.models[0], wt_fasta, teacher_forcing_ratio=1)
nl_pssm = -np.log(pssm)

dcce = des_cce - wt_cce

return dcce, des_cce, wt_cce


def get_ld_balanced_mutants(wt_fasta: str,
mut_positions,
num_seqs: int = 500,
min_ld: int = 1,
max_ld: int = None):
def get_ld_balanced_mutants(
wt_fasta: str,
mut_positions,
num_seqs: int = 500,
min_ld: int = 1,
max_ld: int = None,
):
if max_ld == None:
max_ld = len(mut_positions)

@@ -95,13 +108,15 @@ def get_ld_balanced_mutants(wt_fasta: str,
return mut_seqs


def get_ld_balanced_cce(model,
wt_fasta,
mut_positions,
device,
num_seqs=500,
min_ld=1,
max_ld=None):
def get_ld_balanced_cce(
model,
wt_fasta,
mut_positions,
device,
num_seqs=500,
min_ld=1,
max_ld=None,
):
wt_inputs = get_inputs_from_fasta(wt_fasta)
wt_seq = load_full_seq(wt_fasta)
mut_seqs = get_ld_balanced_mutants(wt_fasta,
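As the body above shows, get_dcce reduces to the difference of two per-sequence CCE evaluations (dcce = des_cce - wt_cce), with the PSSM taken from the ensemble's first member. A minimal sketch computing that same difference via get_fasta_cce; the load_model import path and file paths are assumptions for illustration:

import torch

from deepab.analysis.design_metrics import get_fasta_cce
from deepab.models.AbResNet import load_model  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical path to trained weights.
model = load_model("trained_models/abresnet.p", eval_mode=True, device=device)

# Designed and wild-type Fv sequences in fasta format (hypothetical paths).
des_fasta = "designs/design_01.fasta"
wt_fasta = "data/wildtype.fasta"

# Per-sequence categorical cross-entropy under the model, differenced as in get_dcce.
des_cce = get_fasta_cce(model, des_fasta, device)
wt_cce = get_fasta_cce(model, wt_fasta, device)
dcce = des_cce - wt_cce

print("design CCE:", des_cce)
print("wild-type CCE:", wt_cce)
print("dCCE:", dcce)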