Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and improve the utilities for PyG #234

Merged
merged 58 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
54613aa
Fix bug in `add_k_nn_edges`.
anton-bushuiev Nov 2, 2022
1de248f
Extend `add_k_nn_edges`.
anton-bushuiev Nov 2, 2022
27ee8af
Add types to docstring
a-r-j Nov 2, 2022
018fd9c
Update changelog
a-r-j Nov 2, 2022
71fee7a
Add `kind_name` argument
anton-bushuiev Nov 2, 2022
74968ce
Test `filter_distmat`
anton-bushuiev Nov 3, 2022
77a89c6
Merge branch 'master' of https://github.com/anton-bushuiev/graphein
anton-bushuiev Nov 3, 2022
c91cede
Merge branch 'a-r-j:master' into master
anton-bushuiev Nov 3, 2022
713d0e3
Merge branch 'master' of https://github.com/anton-bushuiev/graphein
anton-bushuiev Nov 3, 2022
beb15d3
Set default value of `long_interaction_threshold` to 0
anton-bushuiev Nov 3, 2022
584c9f9
Fix filtering bug in `add_k_nn_edges`
anton-bushuiev Nov 4, 2022
b9cc99b
Test `add_k_nn_edges`
anton-bushuiev Nov 4, 2022
fd1b36b
Refactor with `add_edge`
anton-bushuiev Nov 4, 2022
fdc8b96
Fix bug for empty `edges_to_excl`
anton-bushuiev Nov 10, 2022
5075462
Improve `convert_nx_to_pyg`
anton-bushuiev Nov 10, 2022
21f10a1
Fix bug in `plot_pyg_data`
anton-bushuiev Nov 10, 2022
febaa2b
Test `convert_nx_to_pyg` on multimers
anton-bushuiev Nov 10, 2022
48941fa
Merge
anton-bushuiev Nov 10, 2022
e856693
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
9b89b44
Update `CHANGELOG.md`
anton-bushuiev Nov 10, 2022
c3a5e84
Merge branch 'master' of https://github.com/anton-bushuiev/graphein
anton-bushuiev Nov 10, 2022
a80a387
Fix version in `CHANGELOG.md`
anton-bushuiev Nov 10, 2022
629a61c
Handle corner cases
anton-bushuiev Nov 10, 2022
f1fcc29
Handle NaNs in coordinatess
anton-bushuiev Nov 10, 2022
f54a41f
Add PyG install to CI
a-r-j Nov 13, 2022
05f2ef0
typo in CI config
a-r-j Nov 13, 2022
b5156d8
bump torch versions in CI
a-r-j Nov 13, 2022
7f8c9c1
make pyg-related tests conditional pyg installation
a-r-j Nov 13, 2022
daa5c96
Try fixing graph attributes
a-r-j Dec 7, 2022
421e628
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2022
6b126a8
Merge branch 'master' into master
a-r-j Dec 18, 2022
326442d
Fix typo and extend amino acid 3to1, 1to3 mappings
anton-bushuiev Jan 25, 2023
b3dc713
Merge remote-tracking branch 'origin/master'
anton-bushuiev Jan 25, 2023
decac66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2023
fb42f6b
Merge branch 'a-r-j:master' into master
anton-bushuiev Jan 25, 2023
57d0c97
Adapt imports of amino acid codes
anton-bushuiev Jan 26, 2023
e16fd21
Merge branch 'master' of https://github.com/anton-bushuiev/graphein
anton-bushuiev Jan 26, 2023
ed1504a
Merge branch 'a-r-j:master' into master
anton-bushuiev Jan 28, 2023
9ac298a
add semicolon to version
a-r-j Jan 28, 2023
82c6e7f
remove wildcard version number for pyyaml
a-r-j Jan 28, 2023
5e80ee0
Merge branch 'master' into master
a-r-j Jan 30, 2023
eca0cfa
fix typo
a-r-j Jan 30, 2023
3930a53
fix additonal typos
a-r-j Jan 30, 2023
a5903cf
Extend aggregation to vectors
anton-bushuiev Feb 9, 2023
e39a5a7
Implement `aggregate_feature_over_residues`
anton-bushuiev Feb 9, 2023
7179825
Merge remote-tracking branch 'origin/master'
anton-bushuiev Feb 9, 2023
53430aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2023
edd58ef
Add docstring and aggregation type
a-r-j Feb 9, 2023
cc8fa07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2023
e538cc8
import literal from typing extensions
a-r-j Feb 9, 2023
df9f9fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2023
c459122
Merge branch 'a-r-j:master' into master
anton-bushuiev Feb 9, 2023
9092358
Add missing `median` in exception message
anton-bushuiev Feb 9, 2023
dc679ae
Fix `nullcontext`
anton-bushuiev Feb 9, 2023
bd1f4fa
fix dataset test
a-r-j Feb 9, 2023
d1f1c8c
fix division by zero errors in edge colouring
a-r-j Feb 9, 2023
00f99fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2023
c3f8554
update changlelog
a-r-j Feb 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
conda-channels: "conda-forge, salilab, pytorch, pyg"

#- name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
Expand Down
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,29 @@
* [Logging] - [#242](https://github.com/a-r-j/graphein/pull/242) Adds control of protein graph construction logging. Resolves [#238](https://github.com/a-r-j/graphein/issues/238)

#### Protein
* [Feature] - [#229](https://github.com/a-r-j/graphein/pull/220) Adds support for filtering KNN edges based on self-loops and chain membership. Contribution by @anton-bushuiev.
* [Feature] - [#234](https://github.com/a-r-j/graphein/pull/234) Adds support for aggregating node features over residues (`graphein.protein.features.sequence.utils.aggregate_feature_over_residues`).
* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) fixes use of nullcontext in silent graph construction.
* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) Fixes division by zero errors for edge colouring in visualisation.
* [Bugfix] - [#254](https://github.com/a-r-j/graphein/pull/254) Fix peptide bond addition for all atom graphs.
* [Bugfix] - [#223](https://github.com/a-r-j/graphein/pull/220) Fix handling of insertions in protein graphs. Insertions are now given IDs like: `A:SER:12:A`. Contribution by @manonreau.
* [Bugfix] - [#226](https://github.com/a-r-j/graphein/pull/226) Catches failed AF2 structure downloads [#225](https://github.com/a-r-j/graphein/issues/225)
* [Feature] - [#229](https://github.com/a-r-j/graphein/pull/220) Adds support for filtering KNN edges based on self-loops and chain membership. Contribution by @anton-bushuiev.

* [Bugfix] - [#229](https://github.com/a-r-j/graphein/pull/220) Fixes bug in KNN edge computation. Contribution by @anton-bushuiev.
* [Bugfix] - [#220](https://github.com/a-r-j/graphein/pull/220) Fixes edge metadata conversion to PyG. Contribution by @manonreau.
* [Bugfix] - [#220](https://github.com/a-r-j/graphein/pull/220) Fixes centroid atom grouping & avoids unnecessary edge computation where none are found. Contribution by @manonreau.


#### ML
* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) - Fixes bugs and improves `conversion.convert_nx_to_pyg` and `visualisation.plot_pyg_data`. Removes distance matrix (`dist_mat`) from defualt set of features converted to tensor.

#### Utils
* [Improvement] - [#234](https://github.com/a-r-j/graphein/pull/234) - Adds `parse_aggregation_type` to retrieve aggregation functions.

#### Constants
* [Improvement] - [#234](https://github.com/a-r-j/graphein/pull/234) - Adds 1 to 3 mappings to `graphein.protein.resi_atoms`.


#### Documentation
* [Tensor Module] - [#244](https://github.com/a-r-j/graphein/pull/244) Documents new graphein.protein.tensor module.
* [CI] - [#244](https://github.com/a-r-j/graphein/pull/244) Updates to intersphinx maps
Expand Down
49 changes: 36 additions & 13 deletions graphein/ml/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import networkx as nx
import numpy as np
import torch
from loguru import logger as log

from graphein.utils.utils import import_message

Expand Down Expand Up @@ -130,7 +132,6 @@ def __init__(
columns = [
"edge_index",
"coords",
"dist_mat",
"name",
"node_id",
]
Expand All @@ -139,7 +140,6 @@ def __init__(
"b_factor",
"chain_id",
"coords",
"dist_mat",
"edge_index",
"kind",
"name",
Expand Down Expand Up @@ -273,31 +273,54 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
G = nx.convert_node_labels_to_integers(G)

# Construct Edge Index
edge_index = torch.LongTensor(list(G.edges)).t().contiguous()
edge_index = (
torch.LongTensor(list(G.edges)).t().contiguous().view(2, -1)
)

# Add node features
node_feature_names = G.nodes(data=True)[0].keys()
for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
for key, value in feat_dict.items():
if str(key) in self.columns:
data[str(key)] = (
[value] if i == 0 else data[str(key)] + [value]
)
key = str(key)
if key in self.columns:
if i == 0:
data[key] = []
data[key].append(value)

# Add edge features
for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
for key, value in feat_dict.items():
if str(key) in self.columns:
key = str(key)
if key in self.columns or key == "kind":
if i == 0:
data[str(key)] = []
data[str(key)].append(list(value))
data[key] = []
data[key].append(value)

# Add graph-level features
for feat_name in G.graph:
if str(feat_name) in self.columns:
data[str(feat_name)] = [G.graph[feat_name]]

if str(feat_name) not in node_feature_names:
data[str(feat_name)] = G.graph[feat_name]
if "edge_index" in self.columns:
data["edge_index"] = edge_index.view(2, -1)
data["edge_index"] = edge_index

# Split edge index by edge kind
kind_strs = np.array(list(map(lambda x: "_".join(x), data["kind"])))
for kind in set(kind_strs):
key = f"edge_index_{kind}"
if key in self.columns:
mask = kind_strs == kind
data[key] = edge_index[:, mask]
if "kind" not in self.columns:
del data["kind"]

# Convert everything possible to torch.Tensors
for key, val in data.items():
try:
data[key] = torch.tensor(np.array(val))
except Exception as e:
log.warning(e)
pass

data = Data.from_dict(data)
data.num_nodes = G.number_of_nodes()
Expand Down
17 changes: 11 additions & 6 deletions graphein/ml/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def plot_pyg_data(
node_alpha: float = 0.7,
node_size_min: float = 20.0,
node_size_multiplier: float = 20.0,
node_size_feature: str = "degree",
label_node_ids: bool = True,
node_colour_map=plt.cm.plasma,
edge_color_map=plt.cm.plasma,
edge_colour_map=plt.cm.plasma,
colour_nodes_by: str = "residue_name",
colour_edges_by: Optional[str] = None,
) -> go.Figure:
Expand Down Expand Up @@ -75,15 +76,18 @@ def plot_pyg_data(
:param node_size_multiplier: Scales node size by a constant. Node sizes
reflect degree. Defaults to ``20.0``.
:type node_size_multiplier: float
:param node_size_feature: Which feature to scale the node size by. Defaults
to ``degree``.
:type node_size_feature: str
:param label_node_ids: bool indicating whether or not to plot ``node_id``
labels. Defaults to ``True``.
:type label_node_ids: bool
:param node_colour_map: colour map to use for nodes. Defaults to
``plt.cm.plasma``.
:type node_colour_map: plt.cm
:param edge_color_map: colour map to use for edges. Defaults to
:param edge_colour_map: colour map to use for edges. Defaults to
``plt.cm.plasma``.
:type edge_color_map: plt.cm
:type edge_colour_map: plt.cm
:param colour_nodes_by: Specifies how to colour nodes. ``"degree"``,
``"seq_position"`` or a node feature.
:type colour_nodes_by: str
Expand All @@ -100,15 +104,15 @@ def plot_pyg_data(

# Add metadata
nx_graph.name = x.name
nx_graph.graph["coords"] = x.coords[0]
nx_graph.graph["coords"] = x.coords
nx_graph.graph["dist_mat"] = x.dist_mat

# Assign coords and seq info to nodes
for i, (_, d) in enumerate(nx_graph.nodes(data=True)):
d["chain_id"] = x.node_id[i].split(":")[0]
d["residue_name"] = x.node_id[i].split(":")[1]
d["seq_position"] = x.node_id[i].split(":")[2]
d["coords"] = x.coords[0][i]
d["coords"] = x.coords[i]
if node_colour_tensor is not None:
d["colour"] = float(node_colour_tensor[i])

Expand All @@ -124,9 +128,10 @@ def plot_pyg_data(
node_alpha,
node_size_min,
node_size_multiplier,
node_size_feature,
label_node_ids,
node_colour_map,
edge_color_map,
edge_colour_map,
colour_nodes_by if node_colour_tensor is None else "colour",
colour_edges_by if edge_colour_tensor is None else "colour",
)
20 changes: 19 additions & 1 deletion graphein/protein/edges/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def filter_distmat(
edges_to_excl.extend(list(product(nodes0, nodes1)))

# Filter distance matrix based on indices of edges to exclude
if len(exclude_edges):
if len(edges_to_excl):
row_idx_to_excl, col_idx_to_excl = zip(*edges_to_excl)
distmat.iloc[row_idx_to_excl, col_idx_to_excl] = INFINITE_DIST
distmat.iloc[col_idx_to_excl, row_idx_to_excl] = INFINITE_DIST
Expand Down Expand Up @@ -1087,9 +1087,18 @@ def add_k_nn_edges(
:return: Graph with knn-based edges added
:rtype: nx.Graph
"""
# Prepare dataframe
pdb_df = filter_dataframe(
G.graph["pdb_df"], "node_id", list(G.nodes()), True
)
if (
pdb_df["x_coord"].isna().sum()
or pdb_df["y_coord"].isna().sum()
or pdb_df["z_coord"].isna().sum()
):
raise ValueError("Coordinates contain a NaN value.")

# Construct distance matrix
dist_mat = compute_distmat(pdb_df)

# Filter edges
Expand All @@ -1100,6 +1109,12 @@ def add_k_nn_edges(
k -= 1
for n1, n2 in zip(G.nodes(), G.nodes()):
add_edge(G, n1, n2, kind_name)

# Reduce k if number of nodes is less (to avoid sklearn error)
# Note: - 1 because self-loops are not included
if G.number_of_nodes() - 1 < k:
k = G.number_of_nodes() - 1

if k == 0:
return

Expand All @@ -1114,6 +1129,9 @@ def add_k_nn_edges(
interacting_nodes = list(zip(outgoing, incoming))
log.info(f"Found: {len(interacting_nodes)} KNN edges")
for a1, a2 in interacting_nodes:
if dist_mat.loc[a1, a2] == INFINITE_DIST:
continue

# Get nodes IDs from indices
n1 = G.graph["pdb_df"].loc[a1, "node_id"]
n2 = G.graph["pdb_df"].loc[a2, "node_id"]
Expand Down
4 changes: 2 additions & 2 deletions graphein/protein/features/nodes/dssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas as pd
from Bio.PDB.DSSP import dssp_dict_from_pdb_file, residue_max_acc

from graphein.protein.resi_atoms import STANDARD_AMINO_ACID_MAPPING_TO_1_3
from graphein.protein.resi_atoms import STANDARD_AMINO_ACID_MAPPING_1_TO_3
from graphein.protein.utils import is_tool, save_pdb_df_to_pdb

DSSP_COLS = [
Expand Down Expand Up @@ -121,7 +121,7 @@ def add_dssp_df(

dssp_dict = parse_dssp_df(dssp_dict)
# Convert 1 letter aa code to 3 letter
dssp_dict["aa"] = dssp_dict["aa"].map(STANDARD_AMINO_ACID_MAPPING_TO_1_3)
dssp_dict["aa"] = dssp_dict["aa"].map(STANDARD_AMINO_ACID_MAPPING_1_TO_3)

# Resolve UNKs
dssp_dict.loc[dssp_dict["aa"] == "UNK", "aa"] = (
Expand Down
49 changes: 35 additions & 14 deletions graphein/protein/features/sequence/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import networkx as nx
import numpy as np

from graphein.utils.utils import parse_aggregation_type


def compute_feature_over_chains(
G: nx.Graph, func: Callable, feature_name: str
Expand Down Expand Up @@ -56,26 +58,45 @@ def aggregate_feature_over_chains(
:return: Graph with new aggregated feature.
:rtype: nx.Graph
"""

if aggregation_type == "max":
func = np.max
elif aggregation_type == "min":
func = np.min
elif aggregation_type == "mean":
func = np.mean
elif aggregation_type == "sum":
func = np.sum
else:
raise ValueError(
f"Unsupported aggregator: {aggregation_type}. Please use min, max, mean, sum"
)
func = parse_aggregation_type(aggregation_type)

G.graph[f"{feature_name}_{aggregation_type}"] = func(
[G.graph[f"{feature_name}_{c}"] for c in G.graph["chain_ids"]]
[G.graph[f"{feature_name}_{c}"] for c in G.graph["chain_ids"]], axis=0
)
return G


def aggregate_feature_over_residues(
G: nx.Graph,
feature_name: str,
aggregation_type: str,
) -> nx.Graph:
"""
Performs aggregation of a given feature over chains in a graph to produce an aggregated value.

:param G: nx.Graph protein structure graph.
:type G: nx.Graph
:param feature_name: Name of feature to aggregate.
:type feature_name: str
:param aggregation_type: Type of aggregation to perform (``"min"``, ``"max"``, ``"mean"``, ``"sum"``).
:type aggregation_type: str
:raises ValueError: If ``aggregation_type`` is not one of ``"min"``, ``"max"``, ``"mean"``, ``"sum"``.
:return: Graph with new aggregated feature.
:rtype: nx.Graph
"""
func = parse_aggregation_type(aggregation_type)

for c in G.graph["chain_ids"]:
chain_features = []
for n in G.nodes:
if G.nodes[n]["chain_id"] == c:
chain_features.append(G.nodes[n][feature_name])
G.graph[f"{feature_name}_{aggregation_type}_{c}"] = func(
chain_features, axis=0
)
return G


def sequence_to_ngram(sequence: str, N: int) -> List[str]:
"""
Chops a sequence into overlapping N-grams (substrings of length ``N``).
Expand Down
2 changes: 1 addition & 1 deletion graphein/protein/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def construct_graph(
config = ProteinGraphConfig()

# Use progress tracking context if in verbose mode
context = Progress(transient=True) if verbose else nullcontext
context = Progress(transient=True) if verbose else nullcontext()
with context as progress:
if verbose:
task1 = progress.add_task("Reading PDB file...", total=1)
Expand Down
Loading