Skip to content

Commit

Permalink
Add string sequences to Dataloader via "batch.str_seqs".
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanking committed Jun 18, 2021
2 parents c580c00 + 8bf2cc9 commit 96912d2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ The `batch` variable above is a `collections.namedtuple` that has the following
| `batch.pids` | Tuple of ProteinNet/SidechainNet IDs for proteins in this batch |
| `batch.seqs` | Tensor of sequences, either as integers or as one-hot vectors depending on the value of `scn.load(... seq_as_onehot)` |
| `batch.int_seqs` | Tensor of sequences in integer sequence format |
| `batch.str_seqs` | Tuple of sequences as strings (unpadded) |
| `batch.msks` | Tensor of missing residue masks (redundant with padding in data) |
| `batch.evos` | Tensor of Position Specific Scoring Matrix + Information Content |
| `batch.secs` | Tensor of secondary structure, either as integers or as one-hot vectors depending on the value of `scn.load(... seq_as_onehot)` |
Expand Down
4 changes: 3 additions & 1 deletion sidechainnet/dataloaders/ProteinDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self,

# Organize data
self.seqs = [VOCAB.str2ints(s, add_sos_eos) for s in scn_data_split['seq']]
self.str_seqs = scn_data_split['seq']
self.angs = scn_data_split['ang']
self.crds = scn_data_split['crd']
self.msks = [
Expand Down Expand Up @@ -50,6 +51,7 @@ def _sort_by_length(self, reverse_sort):
enumerate(self.angs), key=lambda x: x[1].shape[0], reverse=reverse_sort)
]
self.seqs = [self.seqs[i] for i in sorted_len_indices]
self.str_seqs = [self.str_seqs[i] for i in sorted_len_indices]
self.angs = [self.angs[i] for i in sorted_len_indices]
self.crds = [self.crds[i] for i in sorted_len_indices]
self.msks = [self.msks[i] for i in sorted_len_indices]
Expand All @@ -65,7 +67,7 @@ def __len__(self):
def __getitem__(self, idx):
return (self.ids[idx], self.seqs[idx], self.msks[idx], self.evos[idx],
self.secs[idx], self.angs[idx], self.crds[idx], self.resolutions[idx],
self.mods[idx])
self.mods[idx], self.str_seqs[idx])

def __str__(self):
"""Describe this dataset to the user."""
Expand Down
10 changes: 6 additions & 4 deletions sidechainnet/dataloaders/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Batch = collections.namedtuple("Batch",
"pids seqs msks evos secs angs "
"crds int_seqs seq_evo_sec resolutions is_modified "
"lengths")
"lengths str_seqs")


def get_collate_fn(aggregate_input, seqs_as_onehot=None):
Expand Down Expand Up @@ -60,7 +60,7 @@ def collate_fn(insts):
"""
# Instead of working with a list of tuples, we extract out each category of info
# so it can be padded and re-provided to the user.
pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods = list(zip(*insts))
pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list(zip(*insts))
lengths = tuple(len(s) for s in sequences)
max_batch_len = max(lengths)

Expand Down Expand Up @@ -98,7 +98,8 @@ def collate_fn(insts):
seq_evo_sec=None,
resolutions=resolutions,
is_modified=padded_mods,
lengths=lengths)
lengths=lengths,
str_seqs=str_seqs)

# Aggregated model input
elif aggregate_input:
Expand All @@ -117,7 +118,8 @@ def collate_fn(insts):
seq_evo_sec=seq_evo_sec,
resolutions=resolutions,
is_modified=padded_mods,
lengths=lengths)
lengths=lengths,
str_seqs=str_seqs)

return collate_fn

Expand Down

0 comments on commit 96912d2

Please sign in to comment.