Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 18, 2023
1 parent e6e9658 commit 8c821f0
Showing 1 changed file with 123 additions and 38 deletions.
161 changes: 123 additions & 38 deletions graphein/ml/datasets/pdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,28 @@

PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING: List[str] = ["CA", "C4'"]
SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING: List[str] = ["H"]
PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = ["N", "O", "N1", "N9", "N3", "C2", "C4", "C5", "C6"]
SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = ["N", "O", "N1", "N9", "N3", "C2", "C4", "C5", "C6"]
PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = [
"N",
"O",
"N1",
"N9",
"N3",
"C2",
"C4",
"C5",
"C6",
]
SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING: List[str] = [
"N",
"O",
"N1",
"N9",
"N3",
"C2",
"C4",
"C5",
"C6",
]


class PDBManager:
Expand Down Expand Up @@ -1828,10 +1848,18 @@ def select_pdb_by_criterion(
def filter_chains_by_interface_criteria(
self,
pdb: PandasPdb,
primary_interchain_contact_atoms_for_filtering: List[str] = PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING,
secondary_interchain_contact_atoms_not_for_filtering: List[str] = SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING,
primary_hydrogen_bond_atoms_for_filtering: List[str] = PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING,
secondary_hydrogen_bond_atoms_for_filtering: List[str] = SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING,
primary_interchain_contact_atoms_for_filtering: List[
str
] = PRIMARY_INTERCHAIN_CONTACT_ATOMS_FOR_FILTERING,
secondary_interchain_contact_atoms_not_for_filtering: List[
str
] = SECONDARY_INTERCHAIN_CONTACT_ATOMS_NOT_FOR_FILTERING,
primary_hydrogen_bond_atoms_for_filtering: List[
str
] = PRIMARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING,
secondary_hydrogen_bond_atoms_for_filtering: List[
str
] = SECONDARY_HYDROGEN_BOND_ATOMS_FOR_FILTERING,
interface_contact_criterion: float = 7.0,
hydrogen_bond_criterion: float = 3.5,
interface_contact_count: int = 16,
Expand Down Expand Up @@ -1880,42 +1908,97 @@ def filter_chains_by_interface_criteria(
:rtype: PandasPdb
"""
filtered_pdb = copy.deepcopy(pdb)

atom_data = pdb.df[atom_df_name]
unique_chain_ids = atom_data[chain_id_col].unique()

interface_contact_atom_mask = atom_data[atom_name_col].isin(primary_interchain_contact_atoms_for_filtering)
interface_contact_other_atom_mask = ~atom_data[atom_name_col].isin(secondary_interchain_contact_atoms_not_for_filtering)
hydrogen_bond_atom_mask = atom_data[atom_name_col].isin(primary_hydrogen_bond_atoms_for_filtering)
hydrogen_bond_other_atom_mask = atom_data[atom_name_col].isin(secondary_hydrogen_bond_atoms_for_filtering)


interface_contact_atom_mask = atom_data[atom_name_col].isin(
primary_interchain_contact_atoms_for_filtering
)
interface_contact_other_atom_mask = ~atom_data[atom_name_col].isin(
secondary_interchain_contact_atoms_not_for_filtering
)
hydrogen_bond_atom_mask = atom_data[atom_name_col].isin(
primary_hydrogen_bond_atoms_for_filtering
)
hydrogen_bond_other_atom_mask = atom_data[atom_name_col].isin(
secondary_hydrogen_bond_atoms_for_filtering
)

for chain1 in unique_chain_ids:
interface_contact_chain1_mask = (atom_data[chain_id_col] == chain1) & interface_contact_atom_mask
hydrogen_bond_chain1_mask = (atom_data[chain_id_col] == chain1) & hydrogen_bond_atom_mask
interface_contact_chain1_residues = atom_data[interface_contact_chain1_mask]
interface_contact_chain1_mask = (
atom_data[chain_id_col] == chain1
) & interface_contact_atom_mask
hydrogen_bond_chain1_mask = (
atom_data[chain_id_col] == chain1
) & hydrogen_bond_atom_mask
interface_contact_chain1_residues = atom_data[
interface_contact_chain1_mask
]
hydrogen_bond_chain1_atoms = atom_data[hydrogen_bond_chain1_mask]

if np.sum(interface_contact_chain1_mask) == 0 or np.sum(hydrogen_bond_chain1_mask) == 0:

if (
np.sum(interface_contact_chain1_mask) == 0
or np.sum(hydrogen_bond_chain1_mask) == 0
):
continue

interface_contact_chain1_coords = interface_contact_chain1_residues[["x_coord", "y_coord", "z_coord"]].to_numpy()
interface_contact_non_chain1_coords = atom_data.loc[interface_contact_other_atom_mask & (atom_data[chain_id_col] != chain1), ["x_coord", "y_coord", "z_coord"]].to_numpy()
hydrogen_bond_chain1_coords = hydrogen_bond_chain1_atoms[["x_coord", "y_coord", "z_coord"]].to_numpy()
hydrogen_bond_non_chain1_coords = atom_data.loc[hydrogen_bond_other_atom_mask & (atom_data[chain_id_col] != chain1), ["x_coord", "y_coord", "z_coord"]].to_numpy()

interface_contact_distances = cdist(interface_contact_chain1_coords, interface_contact_non_chain1_coords, metric="euclidean")
hydrogen_bond_distances = cdist(hydrogen_bond_chain1_coords, hydrogen_bond_non_chain1_coords, metric="euclidean")

num_interface_contacts = np.sum(interface_contact_distances <= interface_contact_criterion, axis=1).sum()
chain_within_interface = (num_interface_contacts >= interface_contact_count).item()

num_hydrogen_bonds = np.sum(hydrogen_bond_distances <= hydrogen_bond_criterion, axis=1).sum()
chain_with_sufficient_bond_count = (num_hydrogen_bonds >= hydrogen_bond_count).item()

if not chain_within_interface or not chain_with_sufficient_bond_count:
log.info(f"Filtering out chain {chain1} within PDB {pdb.pdb_path}, as it contains {num_interface_contacts} (of {interface_contact_count} required) interface contacts and {num_hydrogen_bonds} (of {hydrogen_bond_count} required) hydrogen bonds")
filtered_pdb.df[atom_df_name] = filtered_pdb.df[atom_df_name][filtered_pdb.df[atom_df_name][chain_id_col] != chain1]


interface_contact_chain1_coords = (
interface_contact_chain1_residues[
["x_coord", "y_coord", "z_coord"]
].to_numpy()
)
interface_contact_non_chain1_coords = atom_data.loc[
interface_contact_other_atom_mask
& (atom_data[chain_id_col] != chain1),
["x_coord", "y_coord", "z_coord"],
].to_numpy()
hydrogen_bond_chain1_coords = hydrogen_bond_chain1_atoms[
["x_coord", "y_coord", "z_coord"]
].to_numpy()
hydrogen_bond_non_chain1_coords = atom_data.loc[
hydrogen_bond_other_atom_mask
& (atom_data[chain_id_col] != chain1),
["x_coord", "y_coord", "z_coord"],
].to_numpy()

interface_contact_distances = cdist(
interface_contact_chain1_coords,
interface_contact_non_chain1_coords,
metric="euclidean",
)
hydrogen_bond_distances = cdist(
hydrogen_bond_chain1_coords,
hydrogen_bond_non_chain1_coords,
metric="euclidean",
)

num_interface_contacts = np.sum(
interface_contact_distances <= interface_contact_criterion,
axis=1,
).sum()
chain_within_interface = (
num_interface_contacts >= interface_contact_count
).item()

num_hydrogen_bonds = np.sum(
hydrogen_bond_distances <= hydrogen_bond_criterion, axis=1
).sum()
chain_with_sufficient_bond_count = (
num_hydrogen_bonds >= hydrogen_bond_count
).item()

if (
not chain_within_interface
or not chain_with_sufficient_bond_count
):
log.info(
f"Filtering out chain {chain1} within PDB {pdb.pdb_path}, as it contains {num_interface_contacts} (of {interface_contact_count} required) interface contacts and {num_hydrogen_bonds} (of {hydrogen_bond_count} required) hydrogen bonds"
)
filtered_pdb.df[atom_df_name] = filtered_pdb.df[atom_df_name][
filtered_pdb.df[atom_df_name][chain_id_col] != chain1
]

return filtered_pdb

def write_out_pdb_chain_groups(
Expand Down Expand Up @@ -2032,7 +2115,9 @@ def write_out_pdb_chain_groups(
pdb_chains = self.select_pdb_by_criterion(
pdb, "chain_id", chains, entry_pdb_code
)
num_pdb_chains = len(pdb_chains.df[atom_df_name].chain_id.unique().tolist())
num_pdb_chains = len(
pdb_chains.df[atom_df_name].chain_id.unique().tolist()
)
if filter_for_interface_contacts and num_pdb_chains > 1:
pdb_chains = self.filter_chains_by_interface_criteria(
pdb=pdb_chains,
Expand Down

0 comments on commit 8c821f0

Please sign in to comment.