Commit: SidechainNet v1.0 update
jonathanking authored Nov 2, 2023
2 parents cfb36b7 + 89fdd1e commit ef5c746
Showing 63 changed files with 14,854 additions and 1,194 deletions.
21 changes: 18 additions & 3 deletions .gitignore
@@ -1,8 +1,11 @@
# Pytorch saved files and associated ProteinNet files
*.pt
*.pkl
*_ids.txt
*.pdb
sidechain*pkl
training*pkl
validation.pkl
data/*/*.pkl

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -107,7 +110,7 @@ ENV/
# mypy
.mypy_cache/

# custom project ignore
# Other custom project ignores for research code/data
errors/
*.cif
.idea/
@@ -128,4 +131,16 @@ resources/dev/*
*slurm*.out
*_*.out
*wandb*
*.chkpt
*.chkpt
sidechainnet/utils/minimize_old.py
sidechainnet/utils/openmmpdb.py
sidechainnet/research/analysis/plots
*.prof_stats*
test_minimize_alpha.png
sidechainnet/research/reconfigure_alignments/**/*.txt
sidechainnet/research/reconfigure_alignments/**/*.pkl
sidechainnet/research/evaluate_energy/**/*.csv
seq_dict.pkl
sidechainnet/research/analysis/minimization/**/*.png
sidechainnet/research/analysis/minimization/**/*.pkl
sidechainnet/research/analysis/minimization/**/*.pdb
220 changes: 79 additions & 141 deletions README.md

Large diffs are not rendered by default.

Binary file modified docs/_static/structure_example.png
5 changes: 3 additions & 2 deletions setup.py
@@ -36,8 +36,9 @@
package_data={
"astral_data": ["resources/astral_data.txt"],
"full_protein_dssp": ["resources/full_protein_dssp_annotations.json"],
"single_domain_dssp": ["resources/single_domain_dssp_annotations.json"]
},
"single_domain_dssp": ["resources/single_domain_dssp_annotations.json"],
"amber_build_params": ["resources/build_params.pkl"]
},

# Allows `setup.py test` to work correctly with pytest
setup_requires=[] + pytest_runner,
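
The new "amber_build_params" entry bundles resources/build_params.pkl with the installed package. A minimal sketch (mine, not part of the commit) of how such a packaged resource is typically located and loaded at runtime; the package name and resource path used below are assumptions based on the entry above:

import pickle
import pkg_resources

# Locate the pickled AMBER build parameters inside the installed package
# (package/path layout is an assumption, not confirmed by this diff).
build_params_path = pkg_resources.resource_filename(
    "sidechainnet", "resources/build_params.pkl")
with open(build_params_path, "rb") as f:
    build_params = pickle.load(f)
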
3 changes: 2 additions & 1 deletion sidechainnet/__init__.py
@@ -14,9 +14,10 @@
from .structure.BatchedStructureBuilder import BatchedStructureBuilder
from .utils.load import load
from .utils.download import VALID_SPLITS, DATA_SPLITS
from .utils.measure import GLOBAL_PAD_CHAR
from .structure.build_info import GLOBAL_PAD_CHAR
from .create import create, create_custom, get_proteinnet_ids, generate_all
from . import utils
from .dataloaders.SCNProtein import SCNProtein, OPENMM_FORCEFIELDS


versions = get_versions()
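
With these imports, the names below are re-exported at the package top level. A short usage sketch (not part of the diff; the load() arguments are illustrative):

import sidechainnet as scn
from sidechainnet import SCNProtein, OPENMM_FORCEFIELDS
from sidechainnet.structure.build_info import GLOBAL_PAD_CHAR

# Load a prepared dataset through the re-exported load() function
data = scn.load(casp_version=12, thinning=30)
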
107 changes: 88 additions & 19 deletions sidechainnet/create.py
@@ -23,10 +23,12 @@

import argparse
from collections import namedtuple
import multiprocessing
import os
from multiprocessing import Pool, cpu_count
import pkg_resources
from sidechainnet.utils.sequence import ONE_TO_THREE_LETTER_MAP
import time

import numpy as np
import prody as pr
@@ -39,9 +41,9 @@
from sidechainnet.utils.manual_adjustment import (manually_adjust_data,
manually_correct_mask,
needs_manual_adjustment)
from sidechainnet.utils.measure import NUM_COORDS_PER_RES
from sidechainnet.utils.organize import get_validation_split_identifiers_from_pnid_list, load_data, organize_data, save_data
from sidechainnet.utils.parse import parse_raw_proteinnet
from sidechainnet.dataloaders.SCNDataset import SCNDataset

PNID_CSV_FILE = None

@@ -50,7 +52,7 @@

ArgsTuple = namedtuple(
"ArgsTuple", "casp_version thinning proteinnet_in proteinnet_out "
"sidechainnet_out regenerate_scdata limit")
"sidechainnet_out regenerate_scdata limit num_cores")


def combine(pn_entry, sc_entry, aligner, pnid):
@@ -74,7 +76,10 @@ def combine(pn_entry, sc_entry, aligner, pnid):

# If there is no corresponding ProteinNet entry, we create a template entry
if pn_entry is None:
seq = get_sequence_from_pnid(pnid)
if "1GJJ" in pnid:
seq = sc_entry['seq']
else:
seq = get_sequence_from_pnid(pnid)
pn_entry = {
"primary": seq,
"evolutionary": np.zeros((len(seq), 21)),
@@ -106,6 +111,7 @@ def combine(pn_entry, sc_entry, aligner, pnid):
mask = manually_correct_mask(pnid, pn_entry, mask)
new_entry["ang"] = expand_data_with_mask(ang, mask)
new_entry["crd"] = expand_data_with_mask(crd, mask)
new_entry["crd"] = new_entry["crd"].reshape(len(new_entry["seq"]), -1, 3)
new_entry["sec"] = expand_data_with_mask(dssp, mask)
new_entry["ums"] = make_unmodified_seq_entry(new_entry["seq"], unmod_seq, mask)
new_entry["mod"] = expand_data_with_mask(is_mod, mask)
@@ -115,7 +121,7 @@ def combine(pn_entry, sc_entry, aligner, pnid):
length = len(pn_entry["primary"])
for k, v in new_entry.items():
if k == "crd":
if len(v) // NUM_COORDS_PER_RES != length:
if len(v) != length:
return {}, "failed"
elif k == "ums":
if len(v.split(" ")) != length:
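
The reshape above stores coordinates grouped per residue, which is why the length check now compares len(crd) against the sequence length directly instead of dividing by a per-residue atom count. A small illustration (mine; the atoms-per-residue value of 14 is an assumption):

import numpy as np

seq_len, atoms_per_res = 5, 14                     # 14 is illustrative
flat_crd = np.zeros((seq_len * atoms_per_res, 3))  # old flat layout
crd = flat_crd.reshape(seq_len, -1, 3)             # new per-residue layout
assert crd.shape == (seq_len, atoms_per_res, 3)
assert len(crd) == seq_len
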
@@ -194,17 +200,22 @@ def get_tuple(pndata, scdata, pnid):


def format_sidechainnet_path(casp_version, training_split):
"""Returns a string representing a .pkl file for a CASP version and training set."""
"""Return a string representing a .pkl file for a CASP version and training set."""
if casp_version == "debug":
return "sidechainnet_debug.pkl"
elif casp_version == 12 and training_split == "scnmin":
return "sidechainnet_scnmin_ommloss_paper.pkl"
elif casp_version == 12 and training_split == "scnunmin":
return "sidechainnet_scnunmin_ommloss_paper.pkl"
return f"sidechainnet_casp{casp_version}_{training_split}.pkl"


def create(casp_version=12,
thinning=30,
sidechainnet_out="./sidechainnet_data",
regenerate_scdata=False,
limit=None):
limit=None,
num_cores=multiprocessing.cpu_count()):
"""Generate the requested SidechainNet dataset and save pickled result files.
This function replicates CLI behavior of calling `python sidechainnet/create.py`.
@@ -234,7 +245,7 @@ def create(casp_version=12,
proteinnet_out = proteinnet_in

args = ArgsTuple(casp_version, thinning, proteinnet_in, proteinnet_out,
sidechainnet_out, regenerate_scdata, limit)
sidechainnet_out, regenerate_scdata, limit, num_cores)
main(args)
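
A usage sketch (mine, not part of the commit) exercising the new num_cores parameter; the values and path are illustrative:

import sidechainnet as scn

scn.create(casp_version=12,
           thinning=30,
           sidechainnet_out="./sidechainnet_data",
           regenerate_scdata=False,
           num_cores=8)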


@@ -247,10 +258,14 @@ def _create(args):
pnids = pnids[:args.limit] # Limit the length of the list for debugging

# Using the ProteinNet IDs as a guide, download the relevant sidechain data
sc_only_data, sc_filename = download_sidechain_data(pnids, args.sidechainnet_out,
args.casp_version, args.thinning,
args.limit, args.proteinnet_in,
args.regenerate_scdata)
sc_only_data, sc_filename = download_sidechain_data(pnids,
args.sidechainnet_out,
args.casp_version,
args.thinning,
args.limit,
args.proteinnet_in,
args.regenerate_scdata,
num_cores=args.num_cores)

# Finally, unify the sidechain data with ProteinNet
sidechainnet_raw = combine_datasets(args.proteinnet_out, sc_only_data)
@@ -264,6 +279,9 @@

def _create_all(args):
"""Generate all thinnings of a particular CASP dataset, starting with the largest."""
from sidechainnet.utils.download import _init_dssp_data
_init_dssp_data()

# First, parse raw proteinnet files into Python dictionaries for convenience
pnids = get_proteinnet_ids(casp_version=args.casp_version, split="all", thinning=100)
pnids = pnids[:args.limit] # Limit the length of the list for debugging
@@ -276,13 +294,17 @@ def _create_all(args):
100,
args.limit,
args.proteinnet_in,
regenerate_scdata=args.regenerate_scdata)
regenerate_scdata=args.regenerate_scdata,
num_cores=args.num_cores)

# Finally, unify the sidechain data with ProteinNet
sidechainnet_raw_100 = combine_datasets(args.proteinnet_out, sc_only_data)

# Generate debug dataset with 200 training examples
sc_outfile = os.path.join(args.sidechainnet_out, format_sidechainnet_path("debug", 0))
sc_outfile = os.path.join(
args.sidechainnet_out,
format_sidechainnet_path("debug", 0).replace("debug",
f"debug{args.casp_version}"))
debug = organize_data(sidechainnet_raw_100,
args.casp_version,
thinning=100,
@@ -304,7 +326,8 @@ def create_custom(pnids,
output_filename,
sidechainnet_out="./sidechainnet_data",
short_description="Custom SidechainNet dataset.",
regenerate_scdata=False):
regenerate_scdata=False,
num_cores=multiprocessing.cpu_count()):
"""Generate a custom SidechainNet dataset from user-specified ProteinNet IDs.
This function utilizes a concatenated version of ProteinNet generated by the author.
@@ -359,7 +382,8 @@ def create_custom(pnids,
limit=None,
proteinnet_in=proteinnet_in,
regenerate_scdata=regenerate_scdata,
output_name=intermediate_filename)
output_name=intermediate_filename,
num_cores=num_cores)

# Finally, unify the sidechain data with ProteinNet
sidechainnet_raw = combine_datasets(proteinnet_out, sc_only_data)
@@ -376,7 +400,7 @@ def create_custom(pnids,
"To load the data in a different format, use sidechainnet.load with the desired\n"
f"options and set 'local_scn_path={sidechainnet_outfile}'.")

return sidechainnet_dict
return SCNDataset(sidechainnet_dict)
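
Because create_custom now returns an SCNDataset rather than the raw dictionary, the result can be used directly as a dataset object. A sketch (the ProteinNet IDs below are hypothetical, and lookup by ID is assumed from SCNDataset's interface):

import sidechainnet as scn

dataset = scn.create_custom(pnids=["1A9U_1_A", "2HHB_1_A"],   # hypothetical IDs
                            output_filename="custom_scn.pkl",
                            short_description="Two-protein demo set.")
protein = dataset["1A9U_1_A"]  # assumed lookup by ProteinNet ID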


def get_proteinnet_ids(casp_version, split, thinning=None):
Expand Down Expand Up @@ -446,7 +470,8 @@ def make_colname(cur_split):
return list(PNID_CSV_FILE[PNID_CSV_FILE[colname]].index.values)


def generate_all():
def generate_all(
num_cores=multiprocessing.cpu_count(), limit=None, regenerate_scdata=False):
"""Generate all SidechainNet datasets for curation and upload."""
import time
import sidechainnet as scn
@@ -456,7 +481,40 @@ def generate_all():
casps = list(range(7, 13))[::-1]
for c in casps:
print("CASP", c)
scn.create(c, "all", regenerate_scdata=False)
scn.create(c,
"all",
regenerate_scdata=regenerate_scdata,
num_cores=num_cores,
limit=limit)


def generate_all_from_proteinnet(proteinnet_dir,
sidechainnet_out,
num_cores=multiprocessing.cpu_count(),
limit=None,
regenerate_scdata=False):
"""Generate all SCN datasets for curation & upload."""
t = time.localtime()
timestamp = time.strftime('%b-%d-%Y-%H%M', t)
pr.startLogfile(f"sidechainnet_generateall_{timestamp}")
casps = list(range(7, 13))[::-1]
for c in casps:
print("CASP", c)
pin = os.path.join(proteinnet_dir, 'casp'+str(c), 'casp'+str(c))
pout = os.path.join(proteinnet_dir, 'casp' + str(c), 'pkl')
os.makedirs(pout, exist_ok=True)
parse_raw_proteinnet(pin, pout, thinning=100)
args = ArgsTuple(
casp_version=c,
thinning='all',
proteinnet_in=pin,
proteinnet_out=pout,
sidechainnet_out=sidechainnet_out,
regenerate_scdata=regenerate_scdata,
limit=limit,
num_cores=num_cores)
# Create all thinnings for the current CASP version
_create_all(args)
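
A sketch invocation (not from the commit); the local ProteinNet directory layout, with one casp<N> subdirectory per release, is an assumption based on the path handling above:

from sidechainnet.create import generate_all_from_proteinnet

generate_all_from_proteinnet(proteinnet_dir="/data/proteinnet",      # assumed layout
                             sidechainnet_out="./sidechainnet_data",
                             num_cores=16)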


def main(args_tuple):
@@ -503,13 +561,24 @@ def main(args_tuple):
default='30',
help='Which \'thinning\' of the ProteinNet training '
'set to parse. {30,50,70,90,95,100}. Default 30.')
parser.add_argument('-c',
'--casp_version',
type=str,
default="12",
help='CASP version.')
parser.add_argument('--num_cores',
type=int,
default=multiprocessing.cpu_count(),
help='Number of CPU cores to use for dataset creation.')
parser.add_argument(
'--regenerate_scdata',
action="store_true",
help=('If True, then regenerate the sidechain-only data even if it already exists'
' locally.'))
args = parser.parse_args()
if args.casp_version != "debug":
args.casp_version = int(args.casp_version)
args_tuple = ArgsTuple(args.casp_version, args.thinning, args.proteinnet_in,
args.proteinnet_out, args.sidechainnet_out,
args.regenerate_scdata, args.limit)
args.regenerate_scdata, args.limit, args.num_cores)
main(args_tuple)
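
Using only the flags visible in this hunk, an example invocation might look like `python sidechainnet/create.py --casp_version 12 --thinning 30 --num_cores 8 --regenerate_scdata` (the ProteinNet and SidechainNet path arguments handled elsewhere in the parser are elided; this is a sketch, not from the commit).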
