Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jaxtyping syntax error #312

Merged
merged 28 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
176d884
add PSW to nonstandard residues
a-r-j Apr 17, 2023
fa89a37
improve insertion and non-standard residue handling
a-r-j Apr 17, 2023
9855b9b
refactor chain selection
a-r-j Apr 17, 2023
f143719
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
3f3b3d9
remove unused verbosity arg
a-r-j Apr 17, 2023
09f05e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
b7475df
fix chain selection in tests
a-r-j Apr 17, 2023
2e0a371
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 17, 2023
d2c1808
fix chain selection in tutorial notebook
a-r-j Apr 17, 2023
fc332c6
fix notebook chain selection
a-r-j Apr 17, 2023
4a67851
fix chain selection typehint
a-r-j Apr 17, 2023
5f648d2
Update changelog
a-r-j Apr 17, 2023
ab26d78
Add NLW to non-standard residues
a-r-j Apr 17, 2023
a449bba
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 17, 2023
afc0f8b
add .ent support
a-r-j Apr 20, 2023
258c94d
add entry for construction from dataframe
a-r-j Apr 20, 2023
c9856ae
add missing stage arg
a-r-j Apr 20, 2023
9e1191a
improve obsolete mapping retrieving to include entries with no replac…
a-r-j Apr 20, 2023
17c38ab
Merge branch 'master' into tensor_fixes
a-r-j Apr 20, 2023
7bf4ff3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 20, 2023
5af9e06
update changelog
a-r-j Apr 21, 2023
e00bdfb
add transforms to foldcomp datasets
a-r-j Apr 22, 2023
31018bc
fix jaxtyping syntax
a-r-j Apr 25, 2023
6e26455
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 25, 2023
3681714
Merge branch 'master' into tensor_fixes
a-r-j Apr 27, 2023
adbdbe1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2023
50ac31b
Update changelog
a-r-j Apr 27, 2023
088ae02
fix double application of transforms
a-r-j Apr 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#### Bugfixes
* Adds missing `stage` parameter to `graphein.ml.datasets.foldcomp_data.FoldCompDataModule.setup()`. [#310](https://github.com/a-r-j/graphein/pull/310)
* Fixes incorrect jaxtyping syntax for variable size dimensions [#312](https://github.com/a-r-j/graphein/pull/312)

#### Other Changes
* Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
Expand All @@ -14,7 +15,7 @@
* Adds the ability to store a dictionary of HETATM positions in `Data`/`Protein` objects created in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Insertions retained by default in the `graphein.protein.tensor` module. I.e. `insertions=True` is now the default behaviour.[#307](https://github.com/a-r-j/graphein/pull/307)

* Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)


### 1.7.0 - UNRELEASED
Expand Down
34 changes: 19 additions & 15 deletions graphein/ml/datasets/foldcomp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
# Code Repository: https://github.com/a-r-j/graphein

import asyncio
import contextlib
import os
import random
import shutil
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, Iterable, List, Optional, Union

import pandas as pd
from biopandas.pdb import PandasPdb
from loguru import logger as log
from sklearn.model_selection import train_test_split
from torch_geometric import transforms as T
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
exclude_ids: Optional[List[str]] = None,
fraction: float = 1.0,
use_graphein: bool = True,
transform: Optional[List[GraphTransform]] = None,
transform: Optional[T.BaseTransform] = None,
):
"""Dataset class for FoldComp databases.

Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(
]
self._get_indices()
super().__init__(
root=self.root, transform=None, pre_transform=None # type: ignore
root=self.root, transform=self.transform, pre_transform=None # type: ignore
)

@property
Expand Down Expand Up @@ -232,14 +232,7 @@ def get(self, idx) -> Union[Data, Protein]:
idx = self.protein_to_idx[idx]
name, pdb = self.db[idx]

out = self.process_pdb(pdb, name)

# Apply transforms, if any
if self.transform is not None:
for transform in self.transform:
out = transform(out)

return out
return self.process_pdb(pdb, name)


class FoldCompLightningDataModule(L.LightningDataModule):
Expand All @@ -252,7 +245,7 @@ def __init__(
train_split: Optional[Union[List[str], float]] = None,
val_split: Optional[Union[List[str], float]] = None,
test_split: Optional[Union[List[str], float]] = None,
transform: Optional[List[GraphTransform]] = None,
transform: Optional[Iterable[Callable]] = None,
num_workers: int = 4,
pin_memory: bool = True,
) -> None:
Expand Down Expand Up @@ -281,7 +274,7 @@ def __init__(
``Data``/``Protein`` object and return a transformed version.
The data object will be transformed before every access.
(default: ``None``).
:type transform: Optional[List[GraphTransform]]
:type transform: Optional[Iterable[Callable]]
:param num_workers: Number of workers to use for data loading, defaults
to ``4``.
:type num_workers: int, optional
Expand All @@ -297,7 +290,12 @@ def __init__(
self.train_split = train_split
self.val_split = val_split
self.test_split = test_split
self.transform = transform
self.transform = (
self._compose_transforms(transform)
if transform is not None
else None
)

if (
isinstance(train_split, float)
and isinstance(val_split, float)
Expand All @@ -311,6 +309,12 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory

def _compose_transforms(self, transforms: Iterable[Callable]) -> T.Compose:
    """Compose a collection of graph transforms into a single transform.

    Accepts either a mapping whose *values* are transform callables
    (e.g. ``Dict[str, Callable]``) or a plain iterable of callables, and
    wraps them in ``torch_geometric.transforms.Compose`` so they are
    applied in order on every data access.

    :param transforms: Iterable of callables — or a dict of callables —
        each taking and returning a ``Data``/``Protein`` object.
    :type transforms: Iterable[Callable]
    :returns: A single composed transform.
    :rtype: T.Compose
    """
    # Distinguish the dict case explicitly rather than with a blanket
    # try/except Exception, which would silently swallow and misroute
    # unrelated errors raised while composing.
    if isinstance(transforms, dict):
        return T.Compose(list(transforms.values()))
    # Materialise the iterable so generators are consumed exactly once
    # and Compose holds a stable list of transforms.
    return T.Compose(list(transforms))

def setup(self, stage: Optional[str] = None):
self.train_dataset()
self.val_dataset()
Expand Down
38 changes: 23 additions & 15 deletions graphein/protein/tensor/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
# Code Repository: https://github.com/a-r-j/graphein
from typing import NewType, Optional, Union

from jaxtyping import Float
from jaxtyping import Float, Int
from torch import Tensor

# Positions
AtomTensor = NewType("AtomTensor", Float[Tensor, "-1 37 3"])
AtomTensor = NewType("AtomTensor", Float[Tensor, "residues 37 3"])
"""
``torch.float[-1, 37, 3]``

Expand All @@ -24,7 +24,7 @@
.. seealso:: :class:`ResidueTensor` :class:`CoordTensor`
"""

BackboneTensor = NewType("BackboneTensor", Float[Tensor, "-1 4 3"])
BackboneTensor = NewType("BackboneTensor", Float[Tensor, "residues 4 3"])
"""
``torch.float[-1, 4, 3]``

Expand All @@ -49,7 +49,7 @@
"""


CoordTensor = NewType("CoordTensor", Float[Tensor, "-1 3"])
CoordTensor = NewType("CoordTensor", Float[Tensor, "nodes 3"])
"""
``torch.float[-1, 3]``

Expand All @@ -68,7 +68,9 @@
"""

# Represenations
BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
BackboneFrameTensor = NewType(
"BackboneFrameTensor", Float[Tensor, "residues 3 3"]
)
"""
``torch.float[-1, 3, 3]``

Expand All @@ -89,9 +91,9 @@


# Rotations
EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "-1 3"])
EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "nodes 3"])

QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "-1 4"])
QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "nodes 4"])
"""
``torch.float[-1, 4]``

Expand All @@ -102,7 +104,7 @@
"""


TransformTensor = NewType("TransformTensor", Float[Tensor, "-1 4 4"])
TransformTensor = NewType("TransformTensor", Float[Tensor, "nodes 4 4"])


RotationMatrix2D = NewType("RotationMatrix2D", Float[Tensor, "2 2"])
Expand Down Expand Up @@ -135,7 +137,9 @@
"""


RotationMatrixTensor = NewType("RotationMatrixTensor", Float[Tensor, "-1 3 3"])
RotationMatrixTensor = NewType(
"RotationMatrixTensor", Float[Tensor, "nodes 3 3"]
)

RotationTensor = NewType(
"RotationTensor", Union[QuaternionTensor, RotationMatrixTensor]
Expand All @@ -144,7 +148,8 @@

# Angles
DihedralTensor = NewType(
"DihedralTensor", Union[Float[Tensor, "-1 3"], Float[Tensor, "-1 6"]]
"DihedralTensor",
Union[Float[Tensor, "residues 3"], Float[Tensor, "residues 6"]],
)
"""
``Union[torch.float[-1, 3], torch.float[-1, 6]]``
Expand All @@ -161,7 +166,8 @@
"""

TorsionTensor = NewType(
"TorsionTensor", Union[Float[Tensor, "-1 4"], Float[Tensor, "-1 8"]]
"TorsionTensor",
Union[Float[Tensor, "residues 4"], Float[Tensor, "residues 8"]],
)
"""
``Union[torch.float[-1, 4], torch.float[-1, 8]]``
Expand All @@ -177,7 +183,9 @@

"""

BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
BackboneFrameTensor = NewType(
"BackboneFrameTensor", Float[Tensor, "residues 3 3"]
)
"""
``torch.float[-1, 3, 3]``

Expand All @@ -198,12 +206,12 @@
.. seealso:: :class:`BackboneFrameTensor`
"""

EdgeTensor = NewType("EdgeTensor", Float[Tensor, "2 -1"])
EdgeTensor = NewType("EdgeTensor", Int[Tensor, "2 edges"])


OrientationTensor = NewType("OrientationTensor", Float[Tensor, "-1 2 3"])
OrientationTensor = NewType("OrientationTensor", Float[Tensor, "nodes 2 3"])


ScalarTensor = NewType("ScalarTensor", Float[Tensor, "-1"])
ScalarTensor = NewType("ScalarTensor", Float[Tensor, "nodes"])

OptTensor = NewType("OptTensor", Optional[Tensor])