Skip to content

Commit

Permalink
Local dataset (#208)
Browse files Browse the repository at this point in the history
* add load local pdb_files to ProteinGraphDataset

* load local pdb_files from a list

* test and black and isort and add CHANGELOG.md

* docstring formatting

* comment jupyter; just run  and ssh or vscode to interact with container

* install dependency of jupyter nbextension

* Fixed Dockerfile and unit tests

* Update change log

Co-authored-by: xutingfeng <xutingfeng@big.ac.cn>
Co-authored-by: Arian Jamasb <arjamasb@gmail.com>
Co-authored-by: ryan <ryan.greenhalgh@hotmail.co.uk>
  • Loading branch information
4 people authored Sep 16, 2022
1 parent c07fc78 commit ae6ef9d
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 22 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
### 1.5.2

#### GRN
* [Bugfix] - [#208](https://github.com/a-r-j/graphein/pull/208) - Resolves SSL issues with RegNetwork.


#### ML
* [Feature] - [#208](https://github.com/a-r-j/graphein/pull/208) support for loading local pdb files by ``ProteinGraphDataset`` and ``InMemoryProteinGraphDataset``.
>by adding a parameter `pdb_paths` and setting `self.raw_dir` to the root path (`self.pdb_path`) of the `pdb_paths` list (there should be only one root path; the pdb files should all be under the same folder).
>
>It will then load pdb files from `self.pdb_path` instead of from `self.raw`.
> If you wish to download from AlphaFold2 or the PDB instead, just set `pdb_paths` to `None` and it falls back to the previous behaviour.
#### CI
* [Bugfix] - [#208](https://github.com/a-r-j/graphein/pull/208) explicitly installs `jupyter_contrib_nbextensions` in Docker.


### 1.5.1

#### Protein
Expand Down
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ RUN conda install -c fvcore -c iopath -c conda-forge fvcore iopath
RUN conda install -c pytorch3d pytorch3d
RUN conda install -c dglteam dgl
RUN conda install -c salilab dssp

RUN conda install -c conda-forge ipywidgets
RUN jupyter nbextension enable --py widgetsnbextension

RUN export CUDA=$(python -c "import torch; print('cu'+torch.version.cuda.replace('.',''))") \
&& export TORCH=$(python -c "import torch; print(torch.__version__)") \
Expand All @@ -54,6 +52,8 @@ RUN export CUDA=$(python -c "import torch; print('cu'+torch.version.cuda.replace
&& pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html --no-cache-dir \
&& pip install torch-geometric --no-cache-dir

RUN pip install jupyter_contrib_nbextensions
RUN jupyter nbextension enable --py widgetsnbextension

# Testing
# docker-compose -f docker-compose.cpu.yml up -d --build
Expand Down
18 changes: 14 additions & 4 deletions graphein/grn/parse_regnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
import wget
import ssl

from graphein.utils.utils import filter_dataframe, ping

Expand Down Expand Up @@ -41,10 +42,10 @@ def _download_RegNetwork(
"RegNetwork is not available. Please check your internet connection or verify at: http://www.regnetworkweb.org"
)

mouse_url = "http://regnetworkweb.org/download/mouse.zip"
mouse_url = "https://regnetworkweb.org/download/mouse.zip"

if network_type == "human":
human_url = "http://www.regnetworkweb.org/download/human.zip"
human_url = "https://regnetworkweb.org/download/human.zip"
url = human_url
elif network_type == "mouse":
url = mouse_url
Expand All @@ -66,8 +67,12 @@ def _download_RegNetwork(
# Download data and unzip
if not os.path.exists(file):
log.info("Downloading RegNetwork ...")
# switch ssl context for unverified download
default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
wget.download(url, compressed_file)

# switch ssl context back to default
ssl._create_default_https_context = default_https_context
with zipfile.ZipFile(compressed_file, "r") as zip_ref:
zip_ref.extractall(out_dir)

Expand All @@ -80,7 +85,7 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
:param root_dir: Path object specifying the location to download RegNetwork to
"""
url = "http://www.regnetworkweb.org/download/RegulatoryDirections.zip"
url = "https://regnetworkweb.org/download/RegulatoryDirections.zip"

if root_dir is None:
root_dir = Path(__file__).parent.parent.parent / "datasets"
Expand All @@ -94,7 +99,12 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
# Download data and unzip
if not os.path.exists(file):
log.info("Downloading RegNetwork reg types ...")
# switch ssl context for unverified download
default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
wget.download(url, compressed_file)
# switch ssl context back to default
ssl._create_default_https_context = default_https_context

with zipfile.ZipFile(compressed_file, "r") as zip_ref:
zip_ref.extractall(out_dir)
Expand Down
66 changes: 57 additions & 9 deletions graphein/ml/datasets/torch_geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
root: str,
name: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, torch.Tensor]] = None,
Expand Down Expand Up @@ -72,6 +73,8 @@ def __init__(
:type root: str
:param name: Name of the dataset. Will be saved to ``data_$name.pt``.
:type name: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to None.
:type pdb_codes: Optional[List[str]], optional
Expand Down Expand Up @@ -135,6 +138,23 @@ def __init__(
else None
)

self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
Expand All @@ -157,6 +177,7 @@ def __init__(
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version
super().__init__(
root,
transform=transform,
Expand All @@ -176,6 +197,13 @@ def processed_file_names(self) -> List[str]:
"""Name of the processed file."""
return [f"data_{self.name}.pt"]

@property
def raw_dir(self) -> str:
    """Directory holding the raw PDB files.

    Falls back to ``<root>/raw`` unless the user supplied local
    ``pdb_paths``, in which case their common parent folder is used.
    """
    if self.pdb_paths is None:
        return os.path.join(self.root, "raw")
    # User provided local pdb files; read them from their own folder.
    return self.pdb_path

def download(self):
"""Download the PDB files from RCSB or Alphafold."""
self.config.pdb_dir = Path(self.raw_dir)
Expand Down Expand Up @@ -298,7 +326,8 @@ def process(self):
class ProteinGraphDataset(Dataset):
def __init__(
self,
root,
root: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
Expand Down Expand Up @@ -327,6 +356,8 @@ def __init__(
:param root: Root directory where the dataset should be saved.
:type root: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to ``None``.
:type pdb_codes: Optional[List[str]], optional
Expand Down Expand Up @@ -388,14 +419,22 @@ def __init__(
if uniprot_ids is not None
else None
)

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
self.af_version = af_version
self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

# Labels & Chains

Expand Down Expand Up @@ -424,6 +463,7 @@ def __init__(
self.num_cores = num_cores
self.pdb_transform = pdb_transform
self.graph_transformation_funcs = graph_transformation_funcs
self.af_version = af_version
super().__init__(
root,
transform=transform,
Expand All @@ -450,6 +490,13 @@ def processed_file_names(self) -> List[str]:
else:
return [f"{pdb}.pt" for pdb in self.structures]

@property
def raw_dir(self) -> str:
    """Directory holding the raw PDB files.

    Falls back to ``<root>/raw`` unless the user supplied local
    ``pdb_paths``, in which case their common parent folder is used.
    """
    if self.pdb_paths is None:
        return os.path.join(self.root, "raw")
    # User provided local pdb files; read them from their own folder.
    return self.pdb_path

def validate_input(self):
if self.graph_label_map is not None:
assert len(self.structures) == len(
Expand Down Expand Up @@ -554,6 +601,7 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:

# Create graph objects
file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in pdbs]

graphs = construct_graphs_mp(
pdb_path_it=file_names,
config=self.config,
Expand Down
Loading

0 comments on commit ae6ef9d

Please sign in to comment.