diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b308349..9db136bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 
 #### Bugfixes
 * Adds missing `stage` parameter to `graphein.ml.datasets.foldcomp_data.FoldCompDataModule.setup()`. [#310](https://github.com/a-r-j/graphein/pull/310)
+* Fixes incorrect jaxtyping syntax for variable-size dimensions. [#312](https://github.com/a-r-j/graphein/pull/312)
 
 #### Other Changes
 * Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
@@ -14,7 +15,7 @@
 * Adds the ability to store a dictionary of HETATM positions in `Data`/`Protein` objects created in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
 * Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
 * Insertions retained by default in the `graphein.protein.tensor` module. I.e. `insertions=True` is now the default behaviour. [#307](https://github.com/a-r-j/graphein/pull/307)
-
+* Adds transform composition to FoldComp Dataset. [#312](https://github.com/a-r-j/graphein/pull/312)
 
 ### 1.7.0 - UNRELEASED
 
diff --git a/graphein/ml/datasets/foldcomp_dataset.py b/graphein/ml/datasets/foldcomp_dataset.py
index b4d8c117..8320e517 100644
--- a/graphein/ml/datasets/foldcomp_dataset.py
+++ b/graphein/ml/datasets/foldcomp_dataset.py
@@ -6,17 +6,17 @@
 # Code Repository: https://github.com/a-r-j/graphein
 
 import asyncio
-import contextlib
 import os
 import random
 import shutil
 from pathlib import Path
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, Iterable, List, Optional, Union
 
 import pandas as pd
 from biopandas.pdb import PandasPdb
 from loguru import logger as log
 from sklearn.model_selection import train_test_split
+from torch_geometric import transforms as T
 from torch_geometric.data import Data, Dataset
 from torch_geometric.loader import DataLoader
 from tqdm import tqdm
@@ -76,7 +76,7 @@ def __init__(
         exclude_ids: Optional[List[str]] = None,
         fraction: float = 1.0,
         use_graphein: bool = True,
-        transform: Optional[List[GraphTransform]] = None,
+        transform: Optional[T.BaseTransform] = None,
     ):
         """Dataset class for FoldComp databases.
 
@@ -124,7 +124,7 @@ def __init__(
         ]
         self._get_indices()
         super().__init__(
-            root=self.root, transform=None, pre_transform=None  # type: ignore
+            root=self.root, transform=self.transform, pre_transform=None  # type: ignore
         )
 
     @property
@@ -232,14 +232,7 @@ def get(self, idx) -> Union[Data, Protein]:
             idx = self.protein_to_idx[idx]
         name, pdb = self.db[idx]
 
-        out = self.process_pdb(pdb, name)
-
-        # Apply transforms, if any
-        if self.transform is not None:
-            for transform in self.transform:
-                out = transform(out)
-
-        return out
+        return self.process_pdb(pdb, name)
 
 
 class FoldCompLightningDataModule(L.LightningDataModule):
@@ -252,7 +245,7 @@ def __init__(
         train_split: Optional[Union[List[str], float]] = None,
         val_split: Optional[Union[List[str], float]] = None,
         test_split: Optional[Union[List[str], float]] = None,
-        transform: Optional[List[GraphTransform]] = None,
+        transform: Optional[Iterable[Callable]] = None,
         num_workers: int = 4,
         pin_memory: bool = True,
     ) -> None:
@@ -281,7 +274,7 @@ def __init__(
             ``Data``/``Protein`` object and return a transformed version.
             The data object will be transformed before every access.
             (default: ``None``).
-        :type transform: Optional[List[GraphTransform]]
+        :type transform: Optional[Iterable[Callable]]
         :param num_workers: Number of workers to use for data loading,
             defaults to ``4``.
         :type num_workers: int, optional
@@ -297,7 +290,12 @@ def __init__(
         self.train_split = train_split
         self.val_split = val_split
         self.test_split = test_split
-        self.transform = transform
+        self.transform = (
+            self._compose_transforms(transform)
+            if transform is not None
+            else None
+        )
+
         if (
             isinstance(train_split, float)
             and isinstance(val_split, float)
@@ -311,6 +309,12 @@ def __init__(
         self.num_workers = num_workers
         self.pin_memory = pin_memory
 
+    def _compose_transforms(self, transforms: Iterable[Callable]) -> T.Compose:
+        try:
+            return T.Compose(list(transforms.values()))
+        except Exception:
+            return T.Compose(transforms)
+
     def setup(self, stage: Optional[str] = None):
         self.train_dataset()
         self.val_dataset()
diff --git a/graphein/protein/tensor/types.py b/graphein/protein/tensor/types.py
index 5bfb5521..4559d3ca 100644
--- a/graphein/protein/tensor/types.py
+++ b/graphein/protein/tensor/types.py
@@ -9,11 +9,11 @@
 # Code Repository: https://github.com/a-r-j/graphein
 from typing import NewType, Optional, Union
 
-from jaxtyping import Float
+from jaxtyping import Float, Int
 from torch import Tensor
 
 # Positions
-AtomTensor = NewType("AtomTensor", Float[Tensor, "-1 37 3"])
+AtomTensor = NewType("AtomTensor", Float[Tensor, "residues 37 3"])
 """
 ``torch.float[-1, 37, 3]``
 
@@ -24,7 +24,7 @@
 .. seealso:: :class:`ResidueTensor` :class:`CoordTensor`
 """
 
-BackboneTensor = NewType("BackboneTensor", Float[Tensor, "-1 4 3"])
+BackboneTensor = NewType("BackboneTensor", Float[Tensor, "residues 4 3"])
 """
 ``torch.float[-1, 4, 3]``
 
@@ -49,7 +49,7 @@
 
 """
 
-CoordTensor = NewType("CoordTensor", Float[Tensor, "-1 3"])
+CoordTensor = NewType("CoordTensor", Float[Tensor, "nodes 3"])
 """
 ``torch.float[-1, 3]``
 
@@ -68,7 +68,9 @@
 """
 
 # Represenations
-BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
+BackboneFrameTensor = NewType(
+    "BackboneFrameTensor", Float[Tensor, "residues 3 3"]
+)
 """
 ``torch.float[-1, 3, 3]``
 
@@ -89,9 +91,9 @@
 
 
 # Rotations
-EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "-1 3"])
+EulerAngleTensor = NewType("EulerAngleTensor", Float[Tensor, "nodes 3"])
 
-QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "-1 4"])
+QuaternionTensor = NewType("QuaternionTensor", Float[Tensor, "nodes 4"])
 """
 ``torch.float[-1, 4]``
 
@@ -102,7 +104,7 @@
 
 """
 
-TransformTensor = NewType("TransformTensor", Float[Tensor, "-1 4 4"])
+TransformTensor = NewType("TransformTensor", Float[Tensor, "nodes 4 4"])
 
 RotationMatrix2D = NewType("RotationMatrix2D", Float[Tensor, "2 2"])
 
@@ -135,7 +137,9 @@
 
 """
 
-RotationMatrixTensor = NewType("RotationMatrixTensor", Float[Tensor, "-1 3 3"])
+RotationMatrixTensor = NewType(
+    "RotationMatrixTensor", Float[Tensor, "nodes 3 3"]
+)
 
 RotationTensor = NewType(
     "RotationTensor", Union[QuaternionTensor, RotationMatrixTensor]
@@ -144,7 +148,8 @@
 
 # Angles
 DihedralTensor = NewType(
-    "DihedralTensor", Union[Float[Tensor, "-1 3"], Float[Tensor, "-1 6"]]
+    "DihedralTensor",
+    Union[Float[Tensor, "residues 3"], Float[Tensor, "residues 6"]],
 )
 """
 ``Union[torch.float[-1, 3], torch.float[-1, 6]]``
@@ -161,7 +166,8 @@
 """
 
 TorsionTensor = NewType(
-    "TorsionTensor", Union[Float[Tensor, "-1 4"], Float[Tensor, "-1 8"]]
+    "TorsionTensor",
+    Union[Float[Tensor, "residues 4"], Float[Tensor, "residues 8"]],
 )
 """
 ``Union[torch.float[-1, 4], torch.float[-1, 8]]``
@@ -177,7 +183,9 @@
 
 """
 
-BackboneFrameTensor = NewType("BackboneFrameTensor", Float[Tensor, "-1 3 3"])
+BackboneFrameTensor = NewType(
+    "BackboneFrameTensor", Float[Tensor, "residues 3 3"]
+)
 """
 ``torch.float[-1, 3, 3]``
 
@@ -198,12 +206,12 @@
 .. seealso:: :class:`BackboneFrameTensor`
 """
 
-EdgeTensor = NewType("EdgeTensor", Float[Tensor, "2 -1"])
+EdgeTensor = NewType("EdgeTensor", Int[Tensor, "2 edges"])
 
-OrientationTensor = NewType("OrientationTensor", Float[Tensor, "-1 2 3"])
+OrientationTensor = NewType("OrientationTensor", Float[Tensor, "nodes 2 3"])
 
 
-ScalarTensor = NewType("ScalarTensor", Float[Tensor, "-1"])
+ScalarTensor = NewType("ScalarTensor", Float[Tensor, "nodes"])
 
 OptTensor = NewType("OptTensor", Optional[Tensor])