Skip to content

Commit

Permalink
[Transform] Add SVDPE Transform Module (#5121)
Browse files Browse the repository at this point in the history
* add SVD positional encoding

* modify importing module

* Fixed certain problems

* Change the test unit to a nonsigular one

* Fixed typo and make accord with lintrunner

* added svd_pe into dgl.rst

* Modified dgl.rst
  • Loading branch information
ZhenyuLU-Heliodore authored and czkkkkkk committed Apr 19, 2023
1 parent d9da420 commit 394c0fd
Show file tree
Hide file tree
Showing 5 changed files with 701 additions and 259 deletions.
3 changes: 1 addition & 2 deletions docs/source/api/python/dgl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,13 @@ Operators for generating new graphs by manipulating the structure of the existin
khop_graph
knn_graph
laplacian_lambda_max
laplacian_pe
line_graph
metapath_reachable_graph
metis_partition
metis_partition_assignment
norm_by_dst
partition_graph_with_halo
radius_graph
random_walk_pe
remove_edges
remove_nodes
remove_self_loop
Expand Down Expand Up @@ -116,6 +114,7 @@ Operators for generating positional encodings of each node.
laplacian_pe
double_radius_node_labeling
shortest_dist
svd_pe

.. _api-partition:

Expand Down
1 change: 1 addition & 0 deletions docs/source/api/python/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ dgl.transforms
RowFeatNormalizer
SIGNDiffusion
ToLevi
SVDPE
80 changes: 80 additions & 0 deletions python/dgl/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
'to_double',
'double_radius_node_labeling',
'shortest_dist',
'svd_pe'
]


Expand Down Expand Up @@ -3913,4 +3914,83 @@ def _get_nodes(pred, i, j):
return F.copy_to(F.tensor(dist, dtype=F.int64), g.device), \
F.copy_to(F.tensor(paths, dtype=F.int64), g.device)


def svd_pe(g, k, padding=False, random_flip=True):
r"""SVD-based Positional Encoding, as introduced in
`Global Self-Attention as a Replacement for Graph Convolution
<https://arxiv.org/pdf/2108.03348.pdf>`__
This function computes the largest :math:`k` singular values and
corresponding left and right singular vectors to form positional encodings.
Parameters
----------
g : DGLGraph
A DGLGraph to be encoded, which must be a homogeneous one.
k : int
Number of largest singular values and corresponding singular vectors
used for positional encoding.
padding : bool, optional
If False, raise an error when :math:`k > N`,
where :math:`N` is the number of nodes in :attr:`g`.
If True, add zero paddings in the end of encoding vectors when
:math:`k > N`.
Default : False.
random_flip : bool, optional
If True, randomly flip the signs of encoding vectors.
Proposed to be activated during training for better generalization.
Default : True.
Returns
-------
Tensor
Return SVD-based positional encodings of shape :math:`(N, 2k)`.
Example
-------
>>> import dgl
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> dgl.svd_pe(g, k=2, padding=False, random_flip=True)
tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01, 0.0000e+00],
[-6.3246e-01, 7.6512e-01, -6.3246e-01, -7.6512e-01],
[ 6.3246e-01, 4.7287e-01, 6.3246e-01, -4.7287e-01],
[-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01],
[ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]])
"""
n = g.num_nodes()
if not padding and n < k:
raise ValueError(
"The number of singular values k must be no greater than the "
"number of nodes n, but " +
f"got {k} and {n} respectively."
)
a = g.adj(ctx=g.device, scipy_fmt="coo").toarray()
u, d, vh = scipy.linalg.svd(a)
v = vh.transpose()
m = min(n, k)
topm_u = u[:, 0:m]
topm_v = v[:, 0:m]
topm_sqrt_d = sparse.diags(np.sqrt(d[0:m]))
encoding = np.concatenate(
((topm_u @ topm_sqrt_d), (topm_v @ topm_sqrt_d)), axis=1
)
# randomly flip row vectors
if random_flip:
rand_sign = 2 * (np.random.rand(n) > 0.5) - 1
flipped_encoding = F.tensor(
rand_sign[:, np.newaxis] * encoding, dtype=F.float32
)
else:
flipped_encoding = F.tensor(encoding, dtype=F.float32)

if n < k:
zero_padding = F.zeros(
[n, 2 * (k - n)], dtype=F.float32, ctx=F.context(flipped_encoding)
)
flipped_encoding = F.cat([flipped_encoding, zero_padding], dim=1)

return flipped_encoding


_init_api("dgl.transform", __name__)
60 changes: 59 additions & 1 deletion python/dgl/transforms/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
'DropEdge',
'AddEdge',
'SIGNDiffusion',
'ToLevi'
'ToLevi',
'SVDPE'
]

def update_graph_structure(g, data_dict, copy_edata=True):
Expand Down Expand Up @@ -1788,3 +1789,60 @@ def __call__(self, g):
utils.set_new_frames(levi_g, node_frames=edge_frames+node_frames)

return levi_g


class SVDPE(BaseTransform):
r"""SVD-based Positional Encoding, as introduced in
`Global Self-Attention as a Replacement for Graph Convolution
<https://arxiv.org/pdf/2108.03348.pdf>`__
This function computes the largest :math:`k` singular values and
corresponding left and right singular vectors to form positional encodings,
which could be stored in ndata.
Parameters
----------
k : int
Number of largest singular values and corresponding singular vectors
used for positional encoding.
feat_name : str, optional
Name to store the computed positional encodings in ndata.
Default : ``svd_pe``
padding : bool, optional
If False, raise an error when :math:`k > N`,
where :math:`N` is the number of nodes in :attr:`g`.
If True, add zero paddings in the end of encodings when :math:`k > N`.
Default : False.
random_flip : bool, optional
If True, randomly flip the signs of encoding vectors.
Proposed to be activated during training for better generalization.
Default : True.
Example
-------
>>> import dgl
>>> from dgl import SVDPE
>>> transform = SVDPE(k=2, feat_name="svd_pe")
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> g_ = transform(g)
>>> print(g_.ndata['svd_pe'])
tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01, 0.0000e+00],
[-6.3246e-01, 7.6512e-01, -6.3246e-01, -7.6512e-01],
[ 6.3246e-01, 4.7287e-01, 6.3246e-01, -4.7287e-01],
[-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01],
[ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]])
"""
def __init__(self, k, feat_name="svd_pe", padding=False, random_flip=True):
self.k = k
self.feat_name = feat_name
self.padding = padding
self.random_flip = random_flip

def __call__(self, g):
encoding = functional.svd_pe(
g, k=self.k, padding=self.padding, random_flip=self.random_flip
)
g.ndata[self.feat_name] = F.copy_to(encoding, g.device)

return g
Loading

0 comments on commit 394c0fd

Please sign in to comment.