diff --git a/CHANGELOG.md b/CHANGELOG.md index b0351cee..0b308349 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,13 @@ * Chain selections are now specified with either `"all"` or a list of strings (e.g. `["A", "B"]`) rather than a single selection string (e.g. `"AB"`). This is a necessary chain due to MMTF support which can have multicharacter chain identifiers. [#307](https://github.com/a-r-j/graphein/pull/307) +#### Bugfixes +* Adds missing `stage` parameter to `graphein.ml.datasets.foldcomp_dataset.FoldCompDataModule.setup()`. [#310](https://github.com/a-r-j/graphein/pull/310) + #### Other Changes +* Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310) +* Adds support for `.ent` files to `graphein.protein.graphs.read_pdb_to_dataframe`. [#310](https://github.com/a-r-j/graphein/pull/310) +* Obsolete residues with no replacement are now returned by `graphein.protein.utils.get_obsolete_mapping`. [#310](https://github.com/a-r-j/graphein/pull/310) * Adds the ability to store a dictionary of HETATM positions in `Data`/`Protein` objects created in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307) * Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307) * Insertions retained by default in the `graphein.protein.tensor` module. I.e. 
`insertions=True` is now the default behaviour.[#307](https://github.com/a-r-j/graphein/pull/307) diff --git a/graphein/ml/datasets/foldcomp_dataset.py b/graphein/ml/datasets/foldcomp_dataset.py index 97a6cc06..b4d8c117 100644 --- a/graphein/ml/datasets/foldcomp_dataset.py +++ b/graphein/ml/datasets/foldcomp_dataset.py @@ -311,7 +311,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory - def setup(self, stage: str): + def setup(self, stage: Optional[str] = None): self.train_dataset() self.val_dataset() self.test_dataset() diff --git a/graphein/protein/graphs.py b/graphein/protein/graphs.py index 1d063602..41060f0b 100644 --- a/graphein/protein/graphs.py +++ b/graphein/protein/graphs.py @@ -98,12 +98,18 @@ def read_pdb_to_dataframe( if path is not None: if isinstance(path, Path): path = os.fsdecode(path) - if path.endswith(".pdb") or path.endswith(".pdb.gz"): + if ( + path.endswith(".pdb") + or path.endswith(".pdb.gz") + or path.endswith(".ent") + ): atomic_df = PandasPdb().read_pdb(path) elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"): atomic_df = PandasMmtf().read_mmtf(path) else: - raise ValueError(f"File {path} must be either .pdb or .mmtf not") + raise ValueError( + f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}" + ) elif uniprot_id is not None: atomic_df = PandasPdb().fetch_pdb( uniprot_id=uniprot_id, source="alphafold2-v3" diff --git a/graphein/protein/tensor/io.py b/graphein/protein/tensor/io.py index 37500408..cc24c6d2 100644 --- a/graphein/protein/tensor/io.py +++ b/graphein/protein/tensor/io.py @@ -98,6 +98,7 @@ def protein_to_pyg( path: Optional[Union[str, os.PathLike]] = None, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = None, + df: Optional[pd.DataFrame] = None, chain_selection: Union[str, List[str]] = "all", deprotonate: bool = True, keep_insertions: bool = True, @@ -182,14 +183,15 @@ def protein_to_pyg( else uniprot_id ) else: - raise ValueError("Must provide 
either a path, PDB code or uniprot ID.") - - df = read_pdb_to_dataframe( - path=path, - pdb_code=pdb_code, - uniprot_id=uniprot_id, - model_index=model_index, - ) + id = None + + if df is None: + df = read_pdb_to_dataframe( + path=path, + pdb_code=pdb_code, + uniprot_id=uniprot_id, + model_index=model_index, + ) if chain_selection != "all": if isinstance(chain_selection, str): chain_selection = [chain_selection] diff --git a/graphein/protein/utils.py b/graphein/protein/utils.py index 79d381d9..b2f1843b 100644 --- a/graphein/protein/utils.py +++ b/graphein/protein/utils.py @@ -54,6 +54,8 @@ def get_obsolete_mapping() -> Dict[str, str]: obs_dict[entry[2].lower().decode("utf-8")] = ( entry[3].lower().decode("utf-8") ) + elif len(entry) == 3: + obs_dict[entry[2].lower().decode("utf-8")] = "" return obs_dict