From 8c821f0cc1aff959fe4aad7f224b7c569c75017c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Aug 2023 20:26:23 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graphein/ml/datasets/pdb_data.py | 161 +++++++++++++++++++++++-------- 1 file changed, 123 insertions(+), 38 deletions(-) diff --git a/graphein/ml/datasets/pdb_data.py b/graphein/ml/datasets/pdb_data.py index c1c82612..24065aec 100644 --- a/graphein/ml/datasets/pdb_data.py +++ b/graphein/ml/datasets/pdb_data.py @@ -27,8 +27,28 @@ PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING: List[str] = ["CA", "C4'"] SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING: List[str] = ["H"] -PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = ["N", "O", "N1", "N9", "N3", "C2", "C4", "C5", "C6"] -SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = ["N", "O", "N1", "N9", "N3", "C2", "C4", "C5", "C6"] +PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = [ + "N", + "O", + "N1", + "N9", + "N3", + "C2", + "C4", + "C5", + "C6", +] +SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = [ + "N", + "O", + "N1", + "N9", + "N3", + "C2", + "C4", + "C5", + "C6", +] class PDBManager: @@ -1828,10 +1848,18 @@ def select_pdb_by_criterion( def filter_chains_by_interface_criteria( self, pdb: PandasPdb, - primary_interchain_contact_atoms_for_filtering: List[str] = PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING, - secondary_interchain_contact_atoms_not_for_filtering: List[str] = SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING, - primary_hydrogen_bond_atoms_for_filtering: List[str] = PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING, - secondary_hydrogen_bond_atoms_for_filtering: List[str] = SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING, + primary_interchain_contact_atoms_for_filtering: List[ + str + ] = PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING, + secondary_interchain_contact_atoms_not_for_filtering: List[ + str + ] = SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING, + primary_hydrogen_bond_atoms_for_filtering: List[ + str + ] = PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING, + secondary_hydrogen_bond_atoms_for_filtering: List[ + str + ] = SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING, interface_contact_criterion: float = 7.0, hydrogen_bond_criterion: float = 3.5, interface_contact_count: int = 16, @@ -1880,42 +1908,97 @@ def filter_chains_by_interface_criteria( :rtype: PandasPdb """ filtered_pdb = copy.deepcopy(pdb) - + atom_data = pdb.df[atom_df_name] unique_chain_ids = atom_data[chain_id_col].unique() - - interface_contact_atom_mask = atom_data[atom_name_col].isin(primary_interchain_contact_atoms_for_filtering) - interface_contact_other_atom_mask = ~atom_data[atom_name_col].isin(secondary_interchain_contact_atoms_not_for_filtering) - hydrogen_bond_atom_mask = atom_data[atom_name_col].isin(primary_hydrogen_bond_atoms_for_filtering) - hydrogen_bond_other_atom_mask = atom_data[atom_name_col].isin(secondary_hydrogen_bond_atoms_for_filtering) - + + interface_contact_atom_mask = atom_data[atom_name_col].isin( + primary_interchain_contact_atoms_for_filtering + ) + interface_contact_other_atom_mask = ~atom_data[atom_name_col].isin( + secondary_interchain_contact_atoms_not_for_filtering + ) + hydrogen_bond_atom_mask = atom_data[atom_name_col].isin( + primary_hydrogen_bond_atoms_for_filtering + ) + hydrogen_bond_other_atom_mask = atom_data[atom_name_col].isin( + secondary_hydrogen_bond_atoms_for_filtering + ) + for chain1 in unique_chain_ids: - interface_contact_chain1_mask = (atom_data[chain_id_col] == chain1) & interface_contact_atom_mask - hydrogen_bond_chain1_mask = (atom_data[chain_id_col] == chain1) & hydrogen_bond_atom_mask - interface_contact_chain1_residues = atom_data[interface_contact_chain1_mask] + interface_contact_chain1_mask = ( + atom_data[chain_id_col] == chain1 + ) & interface_contact_atom_mask + hydrogen_bond_chain1_mask = ( + atom_data[chain_id_col] == chain1 + ) & hydrogen_bond_atom_mask + interface_contact_chain1_residues = atom_data[ + interface_contact_chain1_mask + ] hydrogen_bond_chain1_atoms = atom_data[hydrogen_bond_chain1_mask] - - if np.sum(interface_contact_chain1_mask) == 0 or np.sum(hydrogen_bond_chain1_mask) == 0: + + if ( + np.sum(interface_contact_chain1_mask) == 0 + or np.sum(hydrogen_bond_chain1_mask) == 0 + ): continue - - interface_contact_chain1_coords = interface_contact_chain1_residues[["x_coord", "y_coord", "z_coord"]].to_numpy() - interface_contact_non_chain1_coords = atom_data.loc[interface_contact_other_atom_mask & (atom_data[chain_id_col] != chain1), ["x_coord", "y_coord", "z_coord"]].to_numpy() - hydrogen_bond_chain1_coords = hydrogen_bond_chain1_atoms[["x_coord", "y_coord", "z_coord"]].to_numpy() - hydrogen_bond_non_chain1_coords = atom_data.loc[hydrogen_bond_other_atom_mask & (atom_data[chain_id_col] != chain1), ["x_coord", "y_coord", "z_coord"]].to_numpy() - - interface_contact_distances = cdist(interface_contact_chain1_coords, interface_contact_non_chain1_coords, metric="euclidean") - hydrogen_bond_distances = cdist(hydrogen_bond_chain1_coords, hydrogen_bond_non_chain1_coords, metric="euclidean") - - num_interface_contacts = np.sum(interface_contact_distances <= interface_contact_criterion, axis=1).sum() - chain_within_interface = (num_interface_contacts >= interface_contact_count).item() - - num_hydrogen_bonds = np.sum(hydrogen_bond_distances <= hydrogen_bond_criterion, axis=1).sum() - chain_with_sufficient_bond_count = (num_hydrogen_bonds >= hydrogen_bond_count).item() - - if not chain_within_interface or not chain_with_sufficient_bond_count: - log.info(f"Filtering out chain {chain1} within PDB {pdb.pdb_path}, as it contains {num_interface_contacts} (of {interface_contact_count} required) interface contacts and {num_hydrogen_bonds} (of {hydrogen_bond_count} required) hydrogen bonds") - filtered_pdb.df[atom_df_name] = filtered_pdb.df[atom_df_name][filtered_pdb.df[atom_df_name][chain_id_col] != chain1] - + + interface_contact_chain1_coords = ( + interface_contact_chain1_residues[ + ["x_coord", "y_coord", "z_coord"] + ].to_numpy() + ) + interface_contact_non_chain1_coords = atom_data.loc[ + interface_contact_other_atom_mask + & (atom_data[chain_id_col] != chain1), + ["x_coord", "y_coord", "z_coord"], + ].to_numpy() + hydrogen_bond_chain1_coords = hydrogen_bond_chain1_atoms[ + ["x_coord", "y_coord", "z_coord"] + ].to_numpy() + hydrogen_bond_non_chain1_coords = atom_data.loc[ + hydrogen_bond_other_atom_mask + & (atom_data[chain_id_col] != chain1), + ["x_coord", "y_coord", "z_coord"], + ].to_numpy() + + interface_contact_distances = cdist( + interface_contact_chain1_coords, + interface_contact_non_chain1_coords, + metric="euclidean", + ) + hydrogen_bond_distances = cdist( + hydrogen_bond_chain1_coords, + hydrogen_bond_non_chain1_coords, + metric="euclidean", + ) + + num_interface_contacts = np.sum( + interface_contact_distances <= interface_contact_criterion, + axis=1, + ).sum() + chain_within_interface = ( + num_interface_contacts >= interface_contact_count + ).item() + + num_hydrogen_bonds = np.sum( + hydrogen_bond_distances <= hydrogen_bond_criterion, axis=1 + ).sum() + chain_with_sufficient_bond_count = ( + num_hydrogen_bonds >= hydrogen_bond_count + ).item() + + if ( + not chain_within_interface + or not chain_with_sufficient_bond_count + ): + log.info( + f"Filtering out chain {chain1} within PDB {pdb.pdb_path}, as it contains {num_interface_contacts} (of {interface_contact_count} required) interface contacts and {num_hydrogen_bonds} (of {hydrogen_bond_count} required) hydrogen bonds" + ) + filtered_pdb.df[atom_df_name] = filtered_pdb.df[atom_df_name][ + filtered_pdb.df[atom_df_name][chain_id_col] != chain1 + ] + return filtered_pdb def write_out_pdb_chain_groups( @@ -2032,7 +2115,9 @@ def write_out_pdb_chain_groups( pdb_chains = self.select_pdb_by_criterion( pdb, "chain_id", chains, entry_pdb_code ) - num_pdb_chains = len(pdb_chains.df[atom_df_name].chain_id.unique().tolist()) + num_pdb_chains = len( + pdb_chains.df[atom_df_name].chain_id.unique().tolist() + ) if filter_for_interface_contacts and num_pdb_chains > 1: pdb_chains = self.filter_chains_by_interface_criteria( pdb=pdb_chains,