diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index fb18c220..4e556c5f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -45,7 +45,6 @@ jobs: with: python-version: ${{ matrix.python-version }} conda-channels: "conda-forge, salilab, pytorch, pyg" - #- name: Set up Python ${{ matrix.python-version }} # uses: actions/setup-python@v2 # with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 81d51779..4ef10e4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,15 +13,29 @@ * [Logging] - [#242](https://github.com/a-r-j/graphein/pull/242) Adds control of protein graph construction logging. Resolves [#238](https://github.com/a-r-j/graphein/issues/238) #### Protein +* [Feature] - [#229](https://github.com/a-r-j/graphein/pull/220) Adds support for filtering KNN edges based on self-loops and chain membership. Contribution by @anton-bushuiev. +* [Feature] - [#234](https://github.com/a-r-j/graphein/pull/234) Adds support for aggregating node features over residues (`graphein.protein.features.sequence.utils.aggregate_feature_over_residues`). +* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) fixes use of nullcontext in silent graph construction. +* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) Fixes division by zero errors for edge colouring in visualisation. * [Bugfix] - [#254](https://github.com/a-r-j/graphein/pull/254) Fix peptide bond addition for all atom graphs. * [Bugfix] - [#223](https://github.com/a-r-j/graphein/pull/220) Fix handling of insertions in protein graphs. Insertions are now given IDs like: `A:SER:12:A`. Contribution by @manonreau. * [Bugfix] - [#226](https://github.com/a-r-j/graphein/pull/226) Catches failed AF2 structure downloads [#225](https://github.com/a-r-j/graphein/issues/225) -* [Feature] - [#229](https://github.com/a-r-j/graphein/pull/220) Adds support for filtering KNN edges based on self-loops and chain membership. Contribution by @anton-bushuiev. + * [Bugfix] - [#229](https://github.com/a-r-j/graphein/pull/220) Fixes bug in KNN edge computation. Contribution by @anton-bushuiev. * [Bugfix] - [#220](https://github.com/a-r-j/graphein/pull/220) Fixes edge metadata conversion to PyG. Contribution by @manonreau. * [Bugfix] - [#220](https://github.com/a-r-j/graphein/pull/220) Fixes centroid atom grouping & avoids unnecessary edge computation where none are found. Contribution by @manonreau. +#### ML +* [Bugfix] - [#234](https://github.com/a-r-j/graphein/pull/234) - Fixes bugs and improves `conversion.convert_nx_to_pyg` and `visualisation.plot_pyg_data`. Removes distance matrix (`dist_mat`) from defualt set of features converted to tensor. + +#### Utils +* [Improvement] - [#234](https://github.com/a-r-j/graphein/pull/234) - Adds `parse_aggregation_type` to retrieve aggregation functions. + +#### Constants +* [Improvement] - [#234](https://github.com/a-r-j/graphein/pull/234) - Adds 1 to 3 mappings to `graphein.protein.resi_atoms`. + + #### Documentation * [Tensor Module] - [#244](https://github.com/a-r-j/graphein/pull/244) Documents new graphein.protein.tensor module. * [CI] - [#244](https://github.com/a-r-j/graphein/pull/244) Updates to intersphinx maps diff --git a/graphein/ml/conversion.py b/graphein/ml/conversion.py index 52f85fca..63a87afd 100644 --- a/graphein/ml/conversion.py +++ b/graphein/ml/conversion.py @@ -12,6 +12,8 @@ import networkx as nx import numpy as np +import torch +from loguru import logger as log from graphein.utils.utils import import_message @@ -130,7 +132,6 @@ def __init__( columns = [ "edge_index", "coords", - "dist_mat", "name", "node_id", ] @@ -139,7 +140,6 @@ def __init__( "b_factor", "chain_id", "coords", - "dist_mat", "edge_index", "kind", "name", @@ -273,31 +273,54 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data: G = nx.convert_node_labels_to_integers(G) # Construct Edge Index - edge_index = torch.LongTensor(list(G.edges)).t().contiguous() + edge_index = ( + torch.LongTensor(list(G.edges)).t().contiguous().view(2, -1) + ) # Add node features + node_feature_names = G.nodes(data=True)[0].keys() for i, (_, feat_dict) in enumerate(G.nodes(data=True)): for key, value in feat_dict.items(): - if str(key) in self.columns: - data[str(key)] = ( - [value] if i == 0 else data[str(key)] + [value] - ) + key = str(key) + if key in self.columns: + if i == 0: + data[key] = [] + data[key].append(value) # Add edge features for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): - if str(key) in self.columns: + key = str(key) + if key in self.columns or key == "kind": if i == 0: - data[str(key)] = [] - data[str(key)].append(list(value)) + data[key] = [] + data[key].append(value) # Add graph-level features for feat_name in G.graph: if str(feat_name) in self.columns: - data[str(feat_name)] = [G.graph[feat_name]] - + if str(feat_name) not in node_feature_names: + data[str(feat_name)] = G.graph[feat_name] if "edge_index" in self.columns: - data["edge_index"] = edge_index.view(2, -1) + data["edge_index"] = edge_index + + # Split edge index by edge kind + kind_strs = np.array(list(map(lambda x: "_".join(x), data["kind"]))) + for kind in set(kind_strs): + key = f"edge_index_{kind}" + if key in self.columns: + mask = kind_strs == kind + data[key] = edge_index[:, mask] + if "kind" not in self.columns: + del data["kind"] + + # Convert everything possible to torch.Tensors + for key, val in data.items(): + try: + data[key] = torch.tensor(np.array(val)) + except Exception as e: + log.warning(e) + pass data = Data.from_dict(data) data.num_nodes = G.number_of_nodes() diff --git a/graphein/ml/visualisation.py b/graphein/ml/visualisation.py index db3f0615..cbc08195 100644 --- a/graphein/ml/visualisation.py +++ b/graphein/ml/visualisation.py @@ -42,9 +42,10 @@ def plot_pyg_data( node_alpha: float = 0.7, node_size_min: float = 20.0, node_size_multiplier: float = 20.0, + node_size_feature: str = "degree", label_node_ids: bool = True, node_colour_map=plt.cm.plasma, - edge_color_map=plt.cm.plasma, + edge_colour_map=plt.cm.plasma, colour_nodes_by: str = "residue_name", colour_edges_by: Optional[str] = None, ) -> go.Figure: @@ -75,15 +76,18 @@ def plot_pyg_data( :param node_size_multiplier: Scales node size by a constant. Node sizes reflect degree. Defaults to ``20.0``. :type node_size_multiplier: float + :param node_size_feature: Which feature to scale the node size by. Defaults + to ``degree``. + :type node_size_feature: str :param label_node_ids: bool indicating whether or not to plot ``node_id`` labels. Defaults to ``True``. :type label_node_ids: bool :param node_colour_map: colour map to use for nodes. Defaults to ``plt.cm.plasma``. :type node_colour_map: plt.cm - :param edge_color_map: colour map to use for edges. Defaults to + :param edge_colour_map: colour map to use for edges. Defaults to ``plt.cm.plasma``. - :type edge_color_map: plt.cm + :type edge_colour_map: plt.cm :param colour_nodes_by: Specifies how to colour nodes. ``"degree"``, ``"seq_position"`` or a node feature. :type colour_nodes_by: str @@ -100,7 +104,7 @@ def plot_pyg_data( # Add metadata nx_graph.name = x.name - nx_graph.graph["coords"] = x.coords[0] + nx_graph.graph["coords"] = x.coords nx_graph.graph["dist_mat"] = x.dist_mat # Assign coords and seq info to nodes @@ -108,7 +112,7 @@ def plot_pyg_data( d["chain_id"] = x.node_id[i].split(":")[0] d["residue_name"] = x.node_id[i].split(":")[1] d["seq_position"] = x.node_id[i].split(":")[2] - d["coords"] = x.coords[0][i] + d["coords"] = x.coords[i] if node_colour_tensor is not None: d["colour"] = float(node_colour_tensor[i]) @@ -124,9 +128,10 @@ def plot_pyg_data( node_alpha, node_size_min, node_size_multiplier, + node_size_feature, label_node_ids, node_colour_map, - edge_color_map, + edge_colour_map, colour_nodes_by if node_colour_tensor is None else "colour", colour_edges_by if edge_colour_tensor is None else "colour", ) diff --git a/graphein/protein/edges/distance.py b/graphein/protein/edges/distance.py index 7abf9470..44bcd55d 100644 --- a/graphein/protein/edges/distance.py +++ b/graphein/protein/edges/distance.py @@ -127,7 +127,7 @@ def filter_distmat( edges_to_excl.extend(list(product(nodes0, nodes1))) # Filter distance matrix based on indices of edges to exclude - if len(exclude_edges): + if len(edges_to_excl): row_idx_to_excl, col_idx_to_excl = zip(*edges_to_excl) distmat.iloc[row_idx_to_excl, col_idx_to_excl] = INFINITE_DIST distmat.iloc[col_idx_to_excl, row_idx_to_excl] = INFINITE_DIST @@ -1087,9 +1087,18 @@ def add_k_nn_edges( :return: Graph with knn-based edges added :rtype: nx.Graph """ + # Prepare dataframe pdb_df = filter_dataframe( G.graph["pdb_df"], "node_id", list(G.nodes()), True ) + if ( + pdb_df["x_coord"].isna().sum() + or pdb_df["y_coord"].isna().sum() + or pdb_df["z_coord"].isna().sum() + ): + raise ValueError("Coordinates contain a NaN value.") + + # Construct distance matrix dist_mat = compute_distmat(pdb_df) # Filter edges @@ -1100,6 +1109,12 @@ def add_k_nn_edges( k -= 1 for n1, n2 in zip(G.nodes(), G.nodes()): add_edge(G, n1, n2, kind_name) + + # Reduce k if number of nodes is less (to avoid sklearn error) + # Note: - 1 because self-loops are not included + if G.number_of_nodes() - 1 < k: + k = G.number_of_nodes() - 1 + if k == 0: return @@ -1114,6 +1129,9 @@ def add_k_nn_edges( interacting_nodes = list(zip(outgoing, incoming)) log.info(f"Found: {len(interacting_nodes)} KNN edges") for a1, a2 in interacting_nodes: + if dist_mat.loc[a1, a2] == INFINITE_DIST: + continue + # Get nodes IDs from indices n1 = G.graph["pdb_df"].loc[a1, "node_id"] n2 = G.graph["pdb_df"].loc[a2, "node_id"] diff --git a/graphein/protein/features/nodes/dssp.py b/graphein/protein/features/nodes/dssp.py index 87eb5908..ea0ec237 100644 --- a/graphein/protein/features/nodes/dssp.py +++ b/graphein/protein/features/nodes/dssp.py @@ -14,7 +14,7 @@ import pandas as pd from Bio.PDB.DSSP import dssp_dict_from_pdb_file, residue_max_acc -from graphein.protein.resi_atoms import STANDARD_AMINO_ACID_MAPPING_TO_1_3 +from graphein.protein.resi_atoms import STANDARD_AMINO_ACID_MAPPING_1_TO_3 from graphein.protein.utils import is_tool, save_pdb_df_to_pdb DSSP_COLS = [ @@ -121,7 +121,7 @@ def add_dssp_df( dssp_dict = parse_dssp_df(dssp_dict) # Convert 1 letter aa code to 3 letter - dssp_dict["aa"] = dssp_dict["aa"].map(STANDARD_AMINO_ACID_MAPPING_TO_1_3) + dssp_dict["aa"] = dssp_dict["aa"].map(STANDARD_AMINO_ACID_MAPPING_1_TO_3) # Resolve UNKs dssp_dict.loc[dssp_dict["aa"] == "UNK", "aa"] = ( diff --git a/graphein/protein/features/sequence/utils.py b/graphein/protein/features/sequence/utils.py index 27866e00..d6420429 100644 --- a/graphein/protein/features/sequence/utils.py +++ b/graphein/protein/features/sequence/utils.py @@ -9,6 +9,8 @@ import networkx as nx import numpy as np +from graphein.utils.utils import parse_aggregation_type + def compute_feature_over_chains( G: nx.Graph, func: Callable, feature_name: str @@ -56,26 +58,45 @@ def aggregate_feature_over_chains( :return: Graph with new aggregated feature. :rtype: nx.Graph """ - - if aggregation_type == "max": - func = np.max - elif aggregation_type == "min": - func = np.min - elif aggregation_type == "mean": - func = np.mean - elif aggregation_type == "sum": - func = np.sum - else: - raise ValueError( - f"Unsupported aggregator: {aggregation_type}. Please use min, max, mean, sum" - ) + func = parse_aggregation_type(aggregation_type) G.graph[f"{feature_name}_{aggregation_type}"] = func( - [G.graph[f"{feature_name}_{c}"] for c in G.graph["chain_ids"]] + [G.graph[f"{feature_name}_{c}"] for c in G.graph["chain_ids"]], axis=0 ) return G +def aggregate_feature_over_residues( + G: nx.Graph, + feature_name: str, + aggregation_type: str, +) -> nx.Graph: + """ + Performs aggregation of a given feature over chains in a graph to produce an aggregated value. + + :param G: nx.Graph protein structure graph. + :type G: nx.Graph + :param feature_name: Name of feature to aggregate. + :type feature_name: str + :param aggregation_type: Type of aggregation to perform (``"min"``, ``"max"``, ``"mean"``, ``"sum"``). + :type aggregation_type: str + :raises ValueError: If ``aggregation_type`` is not one of ``"min"``, ``"max"``, ``"mean"``, ``"sum"``. + :return: Graph with new aggregated feature. + :rtype: nx.Graph + """ + func = parse_aggregation_type(aggregation_type) + + for c in G.graph["chain_ids"]: + chain_features = [] + for n in G.nodes: + if G.nodes[n]["chain_id"] == c: + chain_features.append(G.nodes[n][feature_name]) + G.graph[f"{feature_name}_{aggregation_type}_{c}"] = func( + chain_features, axis=0 + ) + return G + + def sequence_to_ngram(sequence: str, N: int) -> List[str]: """ Chops a sequence into overlapping N-grams (substrings of length ``N``). diff --git a/graphein/protein/graphs.py b/graphein/protein/graphs.py index 66712f77..b81dfdbe 100644 --- a/graphein/protein/graphs.py +++ b/graphein/protein/graphs.py @@ -694,7 +694,7 @@ def construct_graph( config = ProteinGraphConfig() # Use progress tracking context if in verbose mode - context = Progress(transient=True) if verbose else nullcontext + context = Progress(transient=True) if verbose else nullcontext() with context as progress: if verbose: task1 = progress.add_task("Reading PDB file...", total=1) diff --git a/graphein/protein/resi_atoms.py b/graphein/protein/resi_atoms.py index e4039bd8..af052eb7 100644 --- a/graphein/protein/resi_atoms.py +++ b/graphein/protein/resi_atoms.py @@ -117,31 +117,41 @@ ``"X"`` denotes unknown (``"UNK"`` or sometimes ``"XAA"``). """ -STANDARD_AMINO_ACID_MAPPING_TO_1_3: Dict[str, str] = { - "A": "ALA", - "C": "CYS", - "D": "ASP", - "E": "GLU", - "F": "PHE", - "G": "GLY", - "H": "HIS", - "I": "ILE", - "K": "LYS", - "L": "LEU", - "M": "MET", - "N": "ASN", - "O": "PYL", - "P": "PRO", - "Q": "GLN", - "R": "ARG", - "S": "SER", - "T": "THR", - "U": "SEC", - "V": "VAL", - "W": "TRP", - "Y": "TYR", - "X": "UNK", +STANDARD_AMINO_ACID_MAPPING_3_TO_1: Dict[str, str] = { + "ALA": "A", + "CYS": "C", + "ASP": "D", + "GLU": "E", + "PHE": "F", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LYS": "K", + "LEU": "L", + "MET": "M", + "ASN": "N", + "PYL": "O", + "PRO": "P", + "GLN": "Q", + "ARG": "R", + "SER": "S", + "THR": "T", + "SEC": "U", + "VAL": "V", + "TRP": "W", + "TYR": "Y", + "UNK": "X", } +""" +Mapping of 3-letter standard amino acids codes to their one-letter form. +""" + +STANDARD_AMINO_ACID_MAPPING_1_TO_3 = { + v: k for k, v in STANDARD_AMINO_ACID_MAPPING_3_TO_1.items() +} +""" +Mapping of 1-letter standard amino acids codes to their three-letter form. +""" NON_STANDARD_AMINO_ACID_MAPPING_3_TO_1: Dict[str, str] = { "CGU": "E", @@ -158,6 +168,16 @@ """ +NON_STANDARD_AMINO_ACID_MAPPING_1_TO_3 = { + v: k for k, v in NON_STANDARD_AMINO_ACID_MAPPING_3_TO_1.items() +} +""" +Mapping of 1-letter non-standard amino acids codes to their three-letter form. + +See: http://ligand-expo.rcsb.org/ +""" + + PROTEIN_ATOMS: List[str] = [ "N", "CA", @@ -387,6 +407,7 @@ } """Default Ordering of atoms (including non-standard residues) in (dimension 1 of) a protein structure tensor.""" + BOND_TYPES: List[str] = [ "hydrophobic", "disulfide", diff --git a/graphein/protein/tensor/io.py b/graphein/protein/tensor/io.py index d8151b6d..4b203e66 100644 --- a/graphein/protein/tensor/io.py +++ b/graphein/protein/tensor/io.py @@ -26,7 +26,7 @@ ATOM_NUMBERING, ELEMENT_SYMBOL_MAP, PROTEIN_ATOMS, - STANDARD_AMINO_ACID_MAPPING_TO_1_3, + STANDARD_AMINO_ACID_MAPPING_1_TO_3, STANDARD_AMINO_ACIDS, ) from .representation import get_full_atom_coords @@ -374,7 +374,7 @@ def to_dataframe( ) if isinstance(residue_types, torch.Tensor): residue_types = [ - STANDARD_AMINO_ACID_MAPPING_TO_1_3[STANDARD_AMINO_ACIDS[a]] + STANDARD_AMINO_ACID_MAPPING_1_TO_3[STANDARD_AMINO_ACIDS[a]] for a in residue_types ] residue_types = [residue_types[a - 1] for a in res_nums] diff --git a/graphein/protein/tensor/testing.py b/graphein/protein/tensor/testing.py index a29e0afa..ef881fd2 100644 --- a/graphein/protein/tensor/testing.py +++ b/graphein/protein/tensor/testing.py @@ -8,7 +8,7 @@ from ..resi_atoms import ( ATOM_NUMBERING_MODIFIED, - STANDARD_AMINO_ACID_MAPPING_TO_1_3, + STANDARD_AMINO_ACID_MAPPING_1_TO_3, ) from .sequence import get_atom_indices from .types import AtomTensor, BackboneTensor, CoordTensor, ResidueTensor @@ -97,7 +97,7 @@ def has_complete_residue( :func:`graphein.protein.tensor.testing.is_complete_structure` """ if len(residue_type) == 1: - residue_type = STANDARD_AMINO_ACID_MAPPING_TO_1_3[residue_type] + residue_type = STANDARD_AMINO_ACID_MAPPING_1_TO_3[residue_type] true_residue_indices = get_atom_indices()[residue_type] def _get_index(y: torch.Tensor) -> Tuple[int, ...]: diff --git a/graphein/protein/visualisation.py b/graphein/protein/visualisation.py index 01f9a320..e188ca82 100644 --- a/graphein/protein/visualisation.py +++ b/graphein/protein/visualisation.py @@ -96,7 +96,7 @@ def colour_nodes( if colour_by == "degree": # Get max number of edges connected to a single node edge_max = max(G.degree[i] for i in G.nodes()) - colors = [colour_map(G.degree[i] / edge_max) for i in G.nodes()] + colors = [colour_map(G.degree[i] / (edge_max + 1)) for i in G.nodes()] elif colour_by == "seq_position": colors = [colour_map(i / n) for i in range(n)] elif colour_by == "chain": @@ -104,7 +104,7 @@ def colour_nodes( chain_colours = dict( zip(chains, list(colour_map(1 / len(chains), 1, len(chains)))) ) - colors = [chain_colours[d["chain_id"]] for n, d in G.nodes(data=True)] + colors = [chain_colours[d["chain_id"]] for _, d in G.nodes(data=True)] elif colour_by == "plddt": levels: List[str] = ["Very High", "Confident", "Low", "Very Low"] mapping = dict(zip(sorted(levels), count())) @@ -173,7 +173,7 @@ def colour_edges( edge_types = set(nx.get_edge_attributes(G, colour_by).values()) mapping = dict(zip(sorted(edge_types), count())) colors = [ - colour_map(mapping[d[colour_by]] / len(edge_types)) + colour_map(mapping[d[colour_by]] / (len(edge_types) + 1)) for _, _, d in G.edges(data=True) ] diff --git a/graphein/utils/utils.py b/graphein/utils/utils.py index dd46e4c3..0021e8e7 100644 --- a/graphein/utils/utils.py +++ b/graphein/utils/utils.py @@ -17,6 +17,10 @@ import pandas as pd import xarray as xr from Bio.Data.IUPACData import protein_letters_3to1 +from typing_extensions import Literal + +AggregationType: List["sum", "mean", "max", "min", "median"] +"""Types of aggregations for features.""" def onek_encoding_unk( @@ -339,9 +343,9 @@ def import_message( :type submodule: str :param package: External package this submodule relies on. :type package: str - :param conda_channel: Conda channel package can be installed from, if at all. Defaults to None + :param conda_channel: Conda channel package can be installed from, if at all. Defaults to ``None``. :type conda_channel: str, optional - :param pip_install: Whether package can be installed via pip. Defaults to False + :param pip_install: Whether package can be installed via pip. Defaults to ``False``. :type pip_install: bool """ is_conda = os.path.exists(os.path.join(sys.prefix, "conda-meta")) @@ -373,7 +377,7 @@ def ping(host: str) -> bool: :param host: IP or hostname :type host: str - :returns: True if host responds to a ping request. + :returns: ``True`` if host responds to a ping request. :rtype: bool """ @@ -384,3 +388,30 @@ def ping(host: str) -> bool: command = ["ping", param, "1", host] return subprocess.call(command) == 0 + + +def parse_aggregation_type(aggregation_type: AggregationType) -> Callable: + """Returns an aggregation function by name + + :param aggregation_type: One of: ``["max", "min", "mean", "median", "sum"]``. + :type aggregration_type: AggregationType + :returns: NumPy aggregation function. + :rtype: Callable + :raises ValueError: if aggregation type is not supported. + """ + if aggregation_type == "max": + func = np.max + elif aggregation_type == "min": + func = np.min + elif aggregation_type == "mean": + func = np.mean + elif aggregation_type == "median": + func = np.median + elif aggregation_type == "sum": + func = np.sum + else: + raise ValueError( + f"Unsupported aggregator: {aggregation_type}." + f" Please use min, max, mean, median, sum" + ) + return func diff --git a/tests/ml/test_conversion.py b/tests/ml/test_conversion.py new file mode 100644 index 00000000..e1f44964 --- /dev/null +++ b/tests/ml/test_conversion.py @@ -0,0 +1,84 @@ +"""Tests for graph format conversion procedures.""" +from functools import partial + +import pytest +import torch + +from graphein.ml import GraphFormatConvertor +from graphein.protein.config import ProteinGraphConfig +from graphein.protein.edges.distance import add_k_nn_edges +from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot +from graphein.protein.graphs import construct_graph + +try: + import torch_geometric + + PYG_AVAIL = True +except ImportError: + PYG_AVAIL = False + + +@pytest.mark.skipif(not PYG_AVAIL, reason="PyG not installed") +@pytest.mark.parametrize("pdb_code", ["10gs", "1bui", "1cw3"]) +def test_nx_to_pyg(pdb_code): + # Construct graph of a multimer protein complex + edge_funcs = { + "edge_construction_functions": [ + partial( + add_k_nn_edges, + k=1, + long_interaction_threshold=0, + exclude_edges=["inter"], + kind_name="intra", + ), + partial( + add_k_nn_edges, + k=1, + long_interaction_threshold=0, + exclude_edges=["intra"], + kind_name="inter", + ), + ] + } + node_feature_funcs = {"node_metadata_functions": [amino_acid_one_hot]} + config = ProteinGraphConfig(**edge_funcs, **node_feature_funcs) + g = construct_graph(config=config, pdb_code=pdb_code) + + # Convert to PyG + convertor = GraphFormatConvertor( + src_format="nx", + dst_format="pyg", + columns=[ + "coords", + "node_id", + "amino_acid_one_hot", + "edge_index_inter", + "edge_index_intra", + "edge_index", + ], + ) + data = convertor(g) + + # Test + # Nodes + assert len(data.node_id) == data.num_nodes + + # Coordinates + assert isinstance(data.coords, torch.Tensor) + assert data.coords.shape == torch.Size([data.num_nodes, 3]) + + # Features + assert isinstance(data.amino_acid_one_hot, torch.Tensor) + assert data.amino_acid_one_hot.shape == torch.Size([data.num_nodes, 20]) + + # Edges + assert isinstance(data.edge_index, torch.Tensor) + assert data.edge_index.shape[0] == 2 + assert isinstance(data.edge_index_inter, torch.Tensor) + assert data.edge_index_inter.shape[0] == 2 + assert isinstance(data.edge_index_intra, torch.Tensor) + assert data.edge_index_intra.shape[0] == 2 + assert ( + data.edge_index.shape[1] + == data.edge_index_inter.shape[1] + data.edge_index_intra.shape[1] + ) diff --git a/tests/ml/test_torch_geometric_dataset.py b/tests/ml/test_torch_geometric_dataset.py index 97dd8855..68c9f99f 100644 --- a/tests/ml/test_torch_geometric_dataset.py +++ b/tests/ml/test_torch_geometric_dataset.py @@ -5,7 +5,6 @@ import pytest from numpy.testing import assert_array_equal -from pandas.testing import assert_frame_equal import graphein.protein as gp from graphein.ml import GraphFormatConvertor @@ -56,7 +55,6 @@ def test_list_dataset(): assert d.node_id == graphs[i].node_id assert_array_equal(d.coords[0], graphs[i].coords[0]) assert d.name == graphs[i].name - assert_frame_equal(d.dist_mat[0], graphs[i].dist_mat[0]) assert d.num_nodes == graphs[i].num_nodes # Clean up shutil.rmtree(ROOT_DIR / "processed")