diff --git a/CHANGELOG.md b/CHANGELOG.md
index 797a32e4..71ed5a4b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@
 * Insertions retained by default in the `graphein.protein.tensor` module. I.e. `insertions=True` is now the default behaviour.[#307](https://github.com/a-r-j/graphein/pull/307)
 * Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
 * Improve FoldComp dataloading performance and include B factors (pLDDT) in output. [#313](https://github.com/a-r-j/graphein/pull/313) [#315](https://github.com/a-r-j/graphein/pull/315)
+* Add new helper functions to PDBManager [#322](https://github.com/a-r-j/graphein/pull/322) (@amorehead)
 
 ### 1.7.0 - UNRELEASED
 
diff --git a/graphein/ml/datasets/pdb_data.py b/graphein/ml/datasets/pdb_data.py
index d7ba7dcb..9082746a 100644
--- a/graphein/ml/datasets/pdb_data.py
+++ b/graphein/ml/datasets/pdb_data.py
@@ -728,6 +728,64 @@ def experiment_type(
             self.df = df
         return df
 
+    def experiment_types(
+        self,
+        types: List[str] = ["diffraction"],
+        splits: Optional[List[str]] = None,
+        update: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Select molecules by experiment types:
+        [``diffraction``, ``NMR``, ``EM``, ``other``]
+
+        :param types: Experiment types of molecules, defaults to "diffraction".
+        :type types: List[str], optional
+        :param splits: Names of splits for which to perform the operation,
+            defaults to ``None``.
+        :type splits: Optional[List[str]], optional
+        :param update: Whether to modify the DataFrame in place, defaults to
+            ``False``.
+        :type update: bool, optional
+
+        :return: DataFrame of selected molecules.
+        :rtype: pd.DataFrame
+        """
+        splits_df = self.get_splits(splits)
+        df = splits_df.loc[splits_df.experiment_type.isin(types)]
+
+        if update:
+            self.df = df
+        return df
+
+    def name(
+        self,
+        substrings: List[str],
+        splits: Optional[List[str]] = None,
+        update: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Select molecules by substrings present in their names:
+        e.g., [``DNA``, ``RNA``]
+
+        :param substrings: Substrings to be found within the name field of each molecule.
+        :type substrings: List[str]
+        :param splits: Names of splits for which to perform the operation,
+            defaults to ``None``.
+        :type splits: Optional[List[str]], optional
+        :param update: Whether to modify the DataFrame in place, defaults to
+            ``False``.
+        :type update: bool, optional
+
+        :return: DataFrame of selected molecules.
+        :rtype: pd.DataFrame
+        """
+        splits_df = self.get_splits(splits)
+        df = splits_df.loc[splits_df.name.str.contains("|".join(substrings))]
+
+        if update:
+            self.df = df
+        return df
+
     def compare_length(
         self,
         length: int,
@@ -1056,6 +1114,41 @@ def remove_non_standard_alphabet_sequences(
             self.df = df
         return df
 
+    def select_complexes_with_grouped_molecule_types(
+        self,
+        molecule_types_to_group: List[str],
+        splits: Optional[List[str]] = None,
+        update: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Select complexes containing at least one instance of each
+        provided molecule type.
+
+        :param molecule_types_to_group: Names of molecule types by which to assemble complexes.
+        :type molecule_types_to_group: List[str]
+        :param splits: Names of splits for which to perform the operation,
+            defaults to ``None``.
+        :type splits: Optional[List[str]], optional
+        :param update: Whether to update the DataFrame in place, defaults to
+            ``False``.
+        :type update: bool, optional
+
+        :return: DataFrame containing only complexes with at least one instance
+            of each provided molecule type.
+        :rtype: pd.DataFrame
+        """
+        splits_df = self.get_splits(splits)
+        df = splits_df.groupby("pdb").filter(
+            lambda group: all(
+                molecule_type_to_group in group["molecule_type"].values
+                for molecule_type_to_group in molecule_types_to_group
+            )
+        )
+
+        if update:
+            self.df = df
+        return df
+
     def split_df_proportionally(
         self,
         df: pd.DataFrame,
@@ -1692,7 +1785,11 @@ def merge_pdb_chain_groups(self, group: DataFrameGroupBy) -> pd.DataFrame:
         )
 
     def select_pdb_by_criterion(
-        self, pdb: PandasPdb, field: str, field_values: List[Any]
+        self,
+        pdb: PandasPdb,
+        field: str,
+        field_values: List[Any],
+        pdb_code: str,
     ) -> PandasPdb:
         """Filter a PDB using a field selection.
 
@@ -1703,6 +1800,8 @@ def select_pdb_by_criterion(
         :param field_values: The field values by which to filter
             the PDB.
         :type field_values: List[Any]
+        :param pdb_code: The PDB code associated with a given PDB object.
+        :type pdb_code: str
 
         :return: The filtered PDB object.
         :rtype: PandasPdb
@@ -1712,10 +1811,10 @@ def select_pdb_by_criterion(
                 filtered_pdb = pdb.df[key][
                     pdb.df[key][field].isin(field_values)
                 ]
-                if "ATOM" in key:
-                    assert (
-                        len(filtered_pdb) > 0
-                    ), "Filtered DataFrame must contain atoms."
+                if "ATOM" in key and len(filtered_pdb) == 0:
+                    log.warning(
+                        f"DataFrame for PDB {pdb_code} does not contain any standard atoms after filtering"
+                    )
                 pdb.df[key] = filtered_pdb
         return pdb
 
@@ -1727,7 +1826,7 @@ def write_out_pdb_chain_groups(
         split: str,
         merge_fn: Callable,
         atom_df_name: str = "ATOM",
-        max_num_chains_per_pdb_code: int = 1,
+        max_num_chains_per_pdb_code: int = -1,
         models: List[int] = [1],
     ):
         """Record groups of PDB codes and associated chains
@@ -1748,7 +1847,7 @@ def write_out_pdb_chain_groups(
             ATOM entries within a PandasPdb object.
         :type atom_df_name: str, defaults to ``ATOM``
         :param max_num_chains_per_pdb_code: Maximum number of chains
-            to collate into a matching PDB file.
+            to collate into a matching PDB file, defaults to ``-1``.
         :type max_num_chains_per_pdb_code: int, optional
         :param models: List of indices of models from which to extract chains,
             defaults to ``[1]``.
@@ -1803,7 +1902,7 @@ def write_out_pdb_chain_groups(
                     else chains[:max_num_chains_per_pdb_code]
                 )
                 pdb_chains = self.select_pdb_by_criterion(
-                    pdb, "chain_id", chains
+                    pdb, "chain_id", chains, entry_pdb_code
                 )
                 # export selected chains within the same PDB file
                 pdb_chains.to_pdb(str(output_pdb_filepath))
@@ -1814,7 +1913,7 @@ def write_df_pdbs(
         df: pd.DataFrame,
         out_dir: str = "collated_pdb",
         splits: Optional[List[str]] = None,
-        max_num_chains_per_pdb_code: int = 1,
+        max_num_chains_per_pdb_code: int = -1,
         models: List[int] = [1],
     ):
         """Write the given selection as a collection of PDB files.
@@ -1831,7 +1930,7 @@ def write_df_pdbs(
             defaults to ``None``.
         :type splits: Optional[List[str]], optional
         :param max_num_chains_per_pdb_code: Maximum number of chains
-            to collate into a matching PDB file.
+            to collate into a matching PDB file, defaults to ``-1``.
         :type max_num_chains_per_pdb_code: int, optional
         :param models: List of indices of models from which to extract chains,
             defaults to ``[1]``.
@@ -1867,7 +1966,7 @@ def export_pdbs(
         self,
         pdb_dir: str,
         splits: Optional[List[str]] = None,
-        max_num_chains_per_pdb_code: int = 1,
+        max_num_chains_per_pdb_code: int = -1,
         models: List[int] = [1],
         force: bool = False,
     ):
@@ -1879,7 +1978,7 @@ def export_pdbs(
             defaults to ``None``.
         :type splits: Optional[List[str]], optional
         :param max_num_chains_per_pdb_code: Maximum number of chains
-            to collate into a matching PDB file.
+            to collate into a matching PDB file, defaults to ``-1``.
         :type max_num_chains_per_pdb_code: int, optional
         :param models: List of indices of models from which to extract chains,
             defaults to ``[1]``.