Skip to content

Commit

Permalink
Support PyG format graphs (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma authored Dec 27, 2023
1 parent d62eb9a commit c537694
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 43 deletions.
269 changes: 227 additions & 42 deletions pygmtools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,13 +1343,13 @@ def _get_md5(filename):


###################################################
# Support NetworkX and GraphML formats #
# Support NetworkX, GraphML formats and PyG #
###################################################


def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None):
def build_aff_mat_from_networkx(G1: nx.Graph, G2: nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None):
r"""
Convert networkx object to Adjacency matrix
Convert networkx object to affinity matrix
:param G1: networkx object, whose type must be networkx.Graph
:param G2: networkx object, whose type must be networkx.Graph
Expand Down Expand Up @@ -1381,7 +1381,7 @@ def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge
# Obtain Affinity Matrix
>>> K = pygm.utils.build_aff_mat_from_networkx(G1, G2)
>>> K.shape
(20,20)
(20, 20)
# The affinity matrices K can be further processed by GM solvers
"""
Expand All @@ -1397,7 +1397,7 @@ def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge

def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=None, backend=None):
r"""
Convert networkx object to Adjacency matrix
Convert networkx object to affinity matrix
:param G1_path: The file path of the graphml object
:param G2_path: The file path of the graphml object
Expand Down Expand Up @@ -1427,7 +1427,7 @@ def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=N
# Obtain Affinity Matrix
>>> K = pygm.utils.build_aff_mat_from_graphml(G1_path, G2_path)
>>> K.shape
(121,121)
(121, 121)
# The affinity matrices K can be further processed by GM solvers
"""
Expand All @@ -1440,10 +1440,71 @@ def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=N
K = build_aff_mat(None, edge1, conn1, None, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend)
return K


def build_aff_mat_from_pyg(G1, G2, node_aff_fn=None, edge_aff_fn=None, backend=None):
r"""
Convert torch_geometric.data.Data object to affinity matrix
def from_networkx(G:nx.Graph):
:param G1: Graph object, whose type must be torch_geometric.data.Data
:param G2: Graph object, whose type must be torch_geometric.data.Data
:param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic
``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and
outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
example.
:param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic
``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and
outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
example.
:param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
:return: the affinity matrix corresponding to the torch_geometric.data.Data object G1 and G2
.. dropdown:: Example
::
>>> import torch
>>> from torch_geometric.data import Data
>>> import pygmtools as pygm
>>> pygm.set_backend('pytorch')
# Generate Graph object
>>> x1 = torch.rand((4, 2), dtype=torch.float)
>>> e1 = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long)
>>> G1 = Data(x=x1, edge_index=e1)
>>> x2 = torch.rand((5, 2), dtype=torch.float)
>>> e2 = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4, 4], [1, 3, 2, 3, 1, 3, 4, 2, 3]], dtype=torch.long)
>>> G2 = Data(x=x2, edge_index=e2)
# Obtain Affinity Matrix
>>> K = pygm.utils.build_aff_mat_from_pyg(G1, G2)
>>> K.shape
(20, 20)
# The affinity matrices K can be further processed by GM solvers
"""
from torch_geometric.data import Data
if type(G1) != Data:
raise ValueError("The type of G1 must be torch_geometric.data.Data")
if type(G2) != Data:
raise ValueError("The type of G2 must be torch_geometric.data.Data")
if backend is None:
backend = 'pytorch'
elif backend != 'pytorch':
raise ValueError("Function 'build_aff_mat_from_pyg' only supports pytorch backend.")
pygmtools.set_backend(backend)
node1 = G1.x
edge1 = G1.edge_attr.reshape(-1, 1) if G1.edge_attr is not None else None
conn1 = G1.edge_index.T if G1.edge_attr is not None else None
node2 = G2.x
edge2 = G2.edge_attr.reshape(-1, 1) if G2.edge_attr is not None else None
conn2 = G2.edge_index.T if G2.edge_attr is not None else None
K = build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend)
return K


def from_networkx(G: nx.Graph):
r"""
Convert networkx object to Adjacency matrix
Convert networkx object to adjacency matrix
:param G: networkx object, whose type must be networkx.Graph
:return: the adjacency matrix corresponding to the networkx object
Expand Down Expand Up @@ -1482,6 +1543,109 @@ def from_networkx(G:nx.Graph):
return adj_matrix


def from_graphml(filename):
r"""
Convert graphml object to adjacency matrix
:param filename: graphml file path
:return: the adjacency matrix corresponding to the graphml object
.. dropdown:: Example
::
>>> import pygmtools as pygm
>>> pygm.set_backend('numpy')
# example file (.graphml) path
>>> G1_path = 'examples/data/graph1.graphml'
>>> G2_path = 'examples/data/graph2.graphml'
# Obtain Adjacency matrix
>>> G1 = pygm.utils.from_graphml(G1_path)
>>> G1.shape
(11,11)
>>> G1 = pygm.utils.from_graphml(G2_path)
>>> G2.shape
(11, 11)
"""
if not filename.endswith('.graphml'):
raise ValueError("File name should end with '.graphml'")
if not os.path.isfile(filename):
raise ValueError("File not found: {}".format(filename))
return from_networkx(nx.read_graphml(filename))


def from_pyg(G):
r"""
Convert torch_geometric.data.Data object to adjacency matrix
:param G: Graph object, whose type must be torch_geometric.data.Data
:return: the adjacency matrix corresponding to the torch_geometric.data.Data
.. dropdown:: Example
::
>>> import torch
>>> from torch_geometric.data import Data
>>> import pygmtools as pygm
>>> pygm.set_backend('pytorch')
# Generate Graph object (edge_attr is 1D edge weights)
>>> edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long)
>>> edge_attr = torch.rand((7), dtype=torch.float)
>>> G = Data(edge_index=edge_index, edge_attr=edge_attr)
>>> G
Data(edge_index=[2, 7], edge_attr=[7])
# Obtain Adjacency matrix
>>> pygm.utils.from_pyg(G)
tensor([[0.0000, 0.2872, 0.5249, 0.0000],
[0.5386, 0.0000, 0.8801, 0.0000],
[0.0966, 0.0000, 0.0000, 0.9825],
[0.0000, 0.4994, 0.0000, 0.0000]])
# Generate Graph object (edge_attr is multi-dimensional edge features)
>>> edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long)
>>> edge_attr = torch.rand((7, 5), dtype=torch.float)
>>> G = Data(edge_index=edge_index, edge_attr=edge_attr)
>>> G
Data(edge_index=[2, 7], edge_attr=[7, 5])
# Obtain Adjacency matrix
>>> pygm.utils.from_pyg(G)
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3776, 0.8405, 0.3963, 0.6111, 0.6220],
[0.4824, 0.6115, 0.5169, 0.2558, 0.8300],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.4206, 0.4795, 0.0512, 0.1543, 0.0133],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1053, 0.9634, 0.1822, 0.8167, 0.4903],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.5127, 0.5046, 0.7905, 0.9613, 0.4695],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5535, 0.1592, 0.0363, 0.2447, 0.7754]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.9172, 0.6820, 0.7201, 0.4397, 0.0732],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
"""
from torch_geometric.utils import to_dense_adj
from torch_geometric.data import Data
if type(G) != Data:
raise ValueError("The type of G must be torch_geometric.data.Data")
if G.edge_attr is not None and G.edge_attr.dim == 2 and G.edge_attr.shape[0] == 1:
G.edge_attr = G.edge_attr[0]
return to_dense_adj(edge_index=G.edge_index, edge_attr=G.edge_attr)[0]


def to_networkx(adj_matrix, backend=None):
"""
Convert adjacency matrix to NetworkX object
Expand Down Expand Up @@ -1520,40 +1684,6 @@ def to_networkx(adj_matrix, backend=None):
return G


def from_graphml(filename):
r"""
Convert graphml object to Adjacency matrix
:param filename: graphml file path
:return: the adjacency matrix corresponding to the graphml object
.. dropdown:: Example
::
>>> import pygmtools as pygm
>>> pygm.set_backend('numpy')
# example file (.graphml) path
>>> G1_path = 'examples/data/graph1.graphml'
>>> G2_path = 'examples/data/graph2.graphml'
# Obtain Adjacency matrix
>>> G1 = pygm.utils.from_graphml(G1_path)
>>> G1.shape
(11,11)
>>> G1 = pygm.utils.from_graphml(G2_path)
>>> G2.shape
(11,11)
"""
if not filename.endswith('.graphml'):
raise ValueError("File name should end with '.graphml'")
if not os.path.isfile(filename):
raise ValueError("File not found: {}".format(filename))
return from_networkx(nx.read_graphml(filename))


def to_graphml(adj_matrix, filename, backend=None):
r"""
Write an adjacency matrix to a GraphML file
Expand Down Expand Up @@ -1590,4 +1720,59 @@ def to_graphml(adj_matrix, filename, backend=None):
[0.15422904, 0.64656912, 0.93219422, 0.784769 ]])
"""
nx.write_graphml(to_networkx(adj_matrix, backend), filename)


def to_pyg(adj_matrix, edge_attr=None, backend=None):
"""
Convert adjacency matrix to torch_geometric.data.Data object
:param adj_matrix: the adjacency matrix to convert, whose type must be torch.Tensor,
it can be 2D matrix (num_nodes, num_nodes) or
3D matrix (num_nodes, num_nodes, num_edge_features)
:param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
:return: the torch_geometric.data.Data object corresponding to the adjacency matrix
.. dropdown:: Example
::
>>> import torch
>>> import pygmtools as pygm
>>> pygm.set_backend('pytorch')
# Generate 2D adjacency matrix (num_nodes, num_nodes)
>>> adj_matrix = torch.rand((4, 4))
# Obtain torch_geometric.data.Data object
>>> pygm.utils.to_pyg(adj_matrix)
Data(edge_index=[2, 16], edge_attr=[16])
# Generate 3D adjacency matrix (num_nodes, num_nodes, num_edge_features)
>>> adj_matrix = torch.rand((4, 4, 3))
# Obtain torch_geometric.data.Data object
>>> pygm.utils.to_pyg(adj_matrix)
Data(edge_index=[2, 16], edge_attr=[16, 3])
"""
import torch
from torch_geometric.data import Data

if backend is None:
backend = 'pytorch'
elif backend != 'pytorch':
raise ValueError("Function 'build_aff_mat_from_pyg' only supports pytorch backend.")
pygmtools.set_backend(backend)
if type(adj_matrix) != torch.Tensor:
raise ValueError("The type of adj_matrix must be torch.Tensor")

if adj_matrix.ndim == 2:
edge_index, edge_attr = dense_to_sparse(adj_matrix, backend=backend)
edge_attr = edge_attr.reshape(-1)
else:
adj = (adj_matrix != 0).any(dim=-1).float()
edge_index, _ = dense_to_sparse(adj, backend=backend)
conn = edge_index.T
edge_attr = adj_matrix[conn[0], conn[1]]

G = Data(x=None, edge_index=edge_index.T, edge_attr=edge_attr)
return G
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ easydict>=1.7
paddlepaddle==2.4.1
protobuf==3.19.5
torch
torch_geometric
tqdm
jittor==1.3.5.37
appdirs>=1.4.4
Expand Down
1 change: 1 addition & 0 deletions tests/requirements_win_mac.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ easydict>=1.7
paddlepaddle==2.4.1
protobuf==3.19.5
torch
torch_geometric
tqdm
appdirs>=1.4.4
tensorflow==2.9.3
Expand Down
Loading

0 comments on commit c537694

Please sign in to comment.