Skip to content

Commit

Permalink
[feat] improve type hinting and add subselecting based on labels
Browse files Browse the repository at this point in the history
  • Loading branch information
kierandidi committed May 30, 2024
1 parent 0ca75eb commit 680d16a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
33 changes: 29 additions & 4 deletions graphein/ml/datasets/pdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
split_ratios: Optional[List[float]] = None,
split_time_frames: Optional[List[np.datetime64]] = None,
assign_leftover_rows_to_split_n: int = 0,
labels: Optional[List[str]] = None,
labels: Optional[List[Literal["uniprot_id", "cath_code", "ec_number"]]] = None,
):
"""Instantiate a selection of experimental PDB structures.
Expand All @@ -61,7 +61,7 @@ def __init__(
:type assign_leftover_rows_to_split_n: int, optional
:param labels: A list of names corresponding to metadata labels that should be included in PDB manager dataframe,
defaults to ``None``.
:type labels: Optional[List[str]], optional
:type labels: Optional[List[Literal["uniprot_id", "cath_code", "ec_number"]]], optional
"""
# Arguments
self.root_dir = Path(root_dir)
Expand Down Expand Up @@ -658,7 +658,7 @@ def _parse_ec_number(self) -> Dict[str, str]:
continue
return ec_mapping

def parse(self, labels: List[str]) -> pd.DataFrame:
def parse(self, labels: Optional[List[Literal["uniprot_id", "cath_code", "ec_number"]]] = None) -> pd.DataFrame:
"""Parse all PDB sequence records.
:param labels: A list of names corresponding to metadata labels that should be included in PDB manager dataframe,
Expand Down Expand Up @@ -1286,12 +1286,17 @@ def select_complexes_with_grouped_molecule_types(

def has_uniprot_id(
self,
select_ids: Optional[List[str]] = None,
splits: Optional[List[str]] = None,
update: bool = False,
) -> pd.DataFrame:
"""
Select entries that have a uniprot ID.
:param select_ids: If given and non-empty, keep only entries whose uniprot ID is in this list.
If ``None`` or empty, keep every entry that has any uniprot ID,
defaults to ``None``.
:type select_ids: Optional[List[str]], optional
:param splits: Names of splits for which to perform the operation,
defaults to ``None``.
:type splits: Optional[List[str]], optional
Expand All @@ -1305,19 +1310,27 @@ def has_uniprot_id(
splits_df = self.get_splits(splits)
df = splits_df.dropna(subset=['uniprot_id'])

if select_ids:
df = df[df['uniprot_id'].isin(select_ids)]

if update:
self.df = df
return df


def has_cath_code(
self,
select_ids: Optional[List[str]] = None,
splits: Optional[List[str]] = None,
update: bool = False,
) -> pd.DataFrame:
"""
Select entries that have a cath code.
:param select_ids: If given and non-empty, keep only entries whose CATH code is in this list.
If ``None`` or empty, keep every entry that has any CATH code,
defaults to ``None``.
:type select_ids: Optional[List[str]], optional
:param splits: Names of splits for which to perform the operation,
defaults to ``None``.
:type splits: Optional[List[str]], optional
Expand All @@ -1331,18 +1344,27 @@ def has_cath_code(
splits_df = self.get_splits(splits)
df = splits_df.dropna(subset=['cath_code'])

if select_ids:
df = df[df['cath_code'].isin(select_ids)]


if update:
self.df = df
return df

def has_ec_number(
self,
select_ids: Optional[List[str]] = None,
splits: Optional[List[str]] = None,
update: bool = False,
) -> pd.DataFrame:
"""
Select entries that have an EC number.
:param select_ids: If given and non-empty, keep only entries whose EC number is in this list.
If ``None`` or empty, keep every entry that has any EC number,
defaults to ``None``.
:type select_ids: Optional[List[str]], optional
:param splits: Names of splits for which to perform the operation,
defaults to ``None``.
:type splits: Optional[List[str]], optional
Expand All @@ -1356,6 +1378,9 @@ def has_ec_number(
splits_df = self.get_splits(splits)
df = splits_df.dropna(subset=['ec_number'])

if select_ids:
df = df[df['ec_number'].isin(select_ids)]

if update:
self.df = df
return df
Expand Down
2 changes: 1 addition & 1 deletion graphein/protein/tensor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def protein_df_to_tensor(
"""
num_residues = get_protein_length(df, insertions=insertions)
df = df.loc[df["atom_name"].isin(atoms_to_keep)]
residue_indices = pd.factorize(get_residue_id(df, unique=False))[0]
residue_indices = pd.factorize(pd.Series(get_residue_id(df, unique=False)))[0]
atom_indices = df["atom_name"].map(lambda x: atoms_to_keep.index(x)).values

positions: AtomTensor = (
Expand Down

0 comments on commit 680d16a

Please sign in to comment.