Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local dataset #208

Merged
merged 10 commits into from
Sep 16, 2022
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
### 1.5.2

#### GRN
* [Bugfix] - [#208](https://github.com/a-r-j/graphein/pull/208) - Resolves SSL issues with RegNetwork.


#### ML
* [Feature] - [#208](https://github.com/a-r-j/graphein/pull/208) support for loading local pdb files by ``ProteinGraphDataset`` and ``InMemoryProteinGraphDataset``.
>This is done by adding a parameter `pdb_paths` and setting `self.raw_dir` to the root path (`self.pdb_path`) of the `pdb_paths` list (there should be a single root path: all pdb files should be under the same folder).
>
>The datasets will then load pdb files from `self.pdb_path` instead of from `self.raw`.
> If you wish to download from AlphaFold or the PDB instead, simply set `pdb_paths` to `None` and the behaviour reverts to the previous version.

#### CI
* [Bugfix] - [#208](https://github.com/a-r-j/graphein/pull/208) explicitly installs `jupyter_contrib_nbextensions` in Docker.


### 1.5.1

#### Protein
Expand Down
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ RUN conda install -c fvcore -c iopath -c conda-forge fvcore iopath
RUN conda install -c pytorch3d pytorch3d
RUN conda install -c dglteam dgl
RUN conda install -c salilab dssp

RUN conda install -c conda-forge ipywidgets
RUN jupyter nbextension enable --py widgetsnbextension

RUN export CUDA=$(python -c "import torch; print('cu'+torch.version.cuda.replace('.',''))") \
&& export TORCH=$(python -c "import torch; print(torch.__version__)") \
Expand All @@ -54,6 +52,8 @@ RUN export CUDA=$(python -c "import torch; print('cu'+torch.version.cuda.replace
&& pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html --no-cache-dir \
&& pip install torch-geometric --no-cache-dir

RUN pip install jupyter_contrib_nbextensions
RUN jupyter nbextension enable --py widgetsnbextension

# Testing
# docker-compose -f docker-compose.cpu.yml up -d --build
Expand Down
18 changes: 14 additions & 4 deletions graphein/grn/parse_regnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
import wget
import ssl

from graphein.utils.utils import filter_dataframe, ping

Expand Down Expand Up @@ -41,10 +42,10 @@ def _download_RegNetwork(
"RegNetwork is not available. Please check your internet connection or verify at: http://www.regnetworkweb.org"
)

mouse_url = "http://regnetworkweb.org/download/mouse.zip"
mouse_url = "https://regnetworkweb.org/download/mouse.zip"

if network_type == "human":
human_url = "http://www.regnetworkweb.org/download/human.zip"
human_url = "https://regnetworkweb.org/download/human.zip"
url = human_url
elif network_type == "mouse":
url = mouse_url
Expand All @@ -66,8 +67,12 @@ def _download_RegNetwork(
# Download data and unzip
if not os.path.exists(file):
log.info("Downloading RegNetwork ...")
# switch ssl context for unverified download
default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
wget.download(url, compressed_file)

# switch ssl context back to default
ssl._create_default_https_context = default_https_context
with zipfile.ZipFile(compressed_file, "r") as zip_ref:
zip_ref.extractall(out_dir)

Expand All @@ -80,7 +85,7 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:

:param root_dir: Path object specifying the location to download RegNetwork to
"""
url = "http://www.regnetworkweb.org/download/RegulatoryDirections.zip"
url = "https://regnetworkweb.org/download/RegulatoryDirections.zip"

if root_dir is None:
root_dir = Path(__file__).parent.parent.parent / "datasets"
Expand All @@ -94,7 +99,12 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
# Download data and unzip
if not os.path.exists(file):
log.info("Downloading RegNetwork reg types ...")
# switch ssl context for unverified download
default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
wget.download(url, compressed_file)
# switch ssl context back to default
ssl._create_default_https_context = default_https_context

with zipfile.ZipFile(compressed_file, "r") as zip_ref:
zip_ref.extractall(out_dir)
Expand Down
66 changes: 57 additions & 9 deletions graphein/ml/datasets/torch_geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
root: str,
name: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, torch.Tensor]] = None,
Expand Down Expand Up @@ -72,6 +73,8 @@ def __init__(
:type root: str
:param name: Name of the dataset. Will be saved to ``data_$name.pt``.
:type name: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to None.
:type pdb_codes: Optional[List[str]], optional
Expand Down Expand Up @@ -135,6 +138,23 @@ def __init__(
else None
)

self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
Expand All @@ -157,6 +177,7 @@ def __init__(
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version
super().__init__(
root,
transform=transform,
Expand All @@ -176,6 +197,13 @@ def processed_file_names(self) -> List[str]:
"""Name of the processed file."""
return [f"data_{self.name}.pt"]

@property
def raw_dir(self) -> str:
    """Directory containing the input PDB files.

    When the dataset was constructed from user-supplied local files
    (``pdb_paths`` is not ``None``), their common parent directory is
    used; otherwise the conventional ``<root>/raw`` folder is returned.

    :return: Path to the directory holding the raw structure files.
    :rtype: str
    """
    if self.pdb_paths is None:
        # Default torch-geometric layout: structures live under <root>/raw.
        return os.path.join(self.root, "raw")
    # Local mode: point straight at the folder the user's pdb files are in.
    return self.pdb_path

def download(self):
"""Download the PDB files from RCSB or Alphafold."""
self.config.pdb_dir = Path(self.raw_dir)
Expand Down Expand Up @@ -298,7 +326,8 @@ def process(self):
class ProteinGraphDataset(Dataset):
def __init__(
self,
root,
root: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
Expand Down Expand Up @@ -327,6 +356,8 @@ def __init__(

:param root: Root directory where the dataset should be saved.
:type root: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to ``None``.
:type pdb_codes: Optional[List[str]], optional
Expand Down Expand Up @@ -388,14 +419,22 @@ def __init__(
if uniprot_ids is not None
else None
)

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
self.af_version = af_version
self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

# Labels & Chains

Expand Down Expand Up @@ -424,6 +463,7 @@ def __init__(
self.num_cores = num_cores
self.pdb_transform = pdb_transform
self.graph_transformation_funcs = graph_transformation_funcs
self.af_version = af_version
super().__init__(
root,
transform=transform,
Expand All @@ -450,6 +490,13 @@ def processed_file_names(self) -> List[str]:
else:
return [f"{pdb}.pt" for pdb in self.structures]

@property
def raw_dir(self) -> str:
    """Location of the raw PDB files for this dataset.

    Falls back to the standard ``<root>/raw`` directory unless the
    dataset was built from user-supplied local paths (``pdb_paths``),
    in which case the parent directory of those files is returned.

    :return: Path to the directory holding the raw structure files.
    :rtype: str
    """
    use_local_files = self.pdb_paths is not None
    # Local mode replaces the raw dir with the user's own pdb folder.
    return self.pdb_path if use_local_files else os.path.join(self.root, "raw")

def validate_input(self):
if self.graph_label_map is not None:
assert len(self.structures) == len(
Expand Down Expand Up @@ -554,6 +601,7 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:

# Create graph objects
file_names = [f"{self.raw_dir}/{pdb}.pdb" for pdb in pdbs]

graphs = construct_graphs_mp(
pdb_path_it=file_names,
config=self.config,
Expand Down
Loading