# From python/dgl/transforms/functional.py (exported as 'svd_pe').
def svd_pe(g, k, padding=False, random_flip=True):
    r"""SVD-based Positional Encoding, as introduced in
    `Global Self-Attention as a Replacement for Graph Convolution
    <https://arxiv.org/abs/2108.03348>`__

    This function computes the largest :math:`k` singular values of the
    adjacency matrix of :attr:`g` and the corresponding left and right
    singular vectors to form positional encodings. With
    :math:`m = \min(N, k)`, columns :math:`[0, m)` of the result hold the
    left singular vectors and columns :math:`[m, 2m)` hold the right
    singular vectors, each scaled by the square root of its singular value.

    Parameters
    ----------
    g : DGLGraph
        A DGLGraph to be encoded, which must be a homogeneous one.
    k : int
        Number of largest singular values and corresponding singular vectors
        used for positional encoding. Must be non-negative.
    padding : bool, optional
        If False, raise an error when :math:`k > N`,
        where :math:`N` is the number of nodes in :attr:`g`.
        If True, add zero paddings in the end of encoding vectors when
        :math:`k > N`, so that the result still has :math:`2k` columns.
        Default : False.
    random_flip : bool, optional
        If True, randomly flip the signs of encoding vectors (one random
        sign per node, i.e. per row of the result).
        Proposed to be activated during training for better generalization.
        Default : True.

    Returns
    -------
    Tensor
        Return SVD-based positional encodings of shape :math:`(N, 2k)`,
        stored as float32.

    Raises
    ------
    ValueError
        If :attr:`k` is negative, or if :math:`k > N` while
        :attr:`padding` is False.

    Notes
    -----
    The adjacency matrix is converted to a dense array and a full SVD is
    computed with :func:`scipy.linalg.svd`; only the top singular triplets
    are kept afterwards.

    The random sign flip draws from NumPy's global random state
    (``np.random.rand``), so seed it with ``np.random.seed`` when
    reproducible encodings are needed.

    Example
    -------
    >>> import dgl

    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> # Row signs vary between runs because random_flip=True.
    >>> dgl.svd_pe(g, k=2, padding=False, random_flip=True)
    tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],
            [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],
            [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],
            [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],
            [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  4.7287e-01]])
    """
    # A negative k would turn the ``[:m]`` slices below into "drop the last
    # |k| columns" and silently return a wrong-shaped encoding.
    if k < 0:
        raise ValueError(
            f"The number of singular values k must be non-negative, but got {k}."
        )
    n = g.num_nodes()
    if not padding and n < k:
        raise ValueError(
            "The number of singular values k must be no greater than the "
            "number of nodes n, but " +
            f"got {k} and {n} respectively."
        )

    # Dense adjacency matrix -> full SVD (singular values come back sorted
    # in non-increasing order, so the leading m entries are the largest).
    a = g.adj(ctx=g.device, scipy_fmt="coo").toarray()
    u, d, vh = scipy.linalg.svd(a)
    m = min(n, k)
    sqrt_d = np.sqrt(d[:m])
    # Broadcasting scales column j of each factor by sqrt(d[j]); the rows of
    # ``vh`` are the right singular vectors, hence the transpose.
    encoding = np.concatenate(
        (u[:, :m] * sqrt_d, vh[:m].T * sqrt_d), axis=1
    )

    # randomly flip row vectors
    if random_flip:
        rand_sign = 2 * (np.random.rand(n) > 0.5) - 1
        encoding = rand_sign[:, np.newaxis] * encoding
    flipped_encoding = F.tensor(encoding, dtype=F.float32)

    if n < k:
        # Only reachable with padding=True: widen the (n, 2n) encoding to
        # (n, 2k) by appending zero columns on the right.
        zero_padding = F.zeros(
            [n, 2 * (k - n)], dtype=F.float32, ctx=F.context(flipped_encoding)
        )
        flipped_encoding = F.cat([flipped_encoding, zero_padding], dim=1)

    return flipped_encoding


# From python/dgl/transforms/module.py (exported as 'SVDPE').
class SVDPE(BaseTransform):
    r"""SVD-based Positional Encoding, as introduced in
    `Global Self-Attention as a Replacement for Graph Convolution
    <https://arxiv.org/abs/2108.03348>`__

    This transform computes the largest :math:`k` singular values and
    corresponding left and right singular vectors to form positional
    encodings, and stores them in the ``ndata`` of the input graph. It is a
    thin wrapper around :func:`functional.svd_pe`.

    Parameters
    ----------
    k : int
        Number of largest singular values and corresponding singular vectors
        used for positional encoding.
    feat_name : str, optional
        Name to store the computed positional encodings in ndata.
        Default : ``svd_pe``
    padding : bool, optional
        If False, raise an error when :math:`k > N`,
        where :math:`N` is the number of nodes in :attr:`g`.
        If True, add zero paddings in the end of encodings when :math:`k > N`.
        Default : False.
    random_flip : bool, optional
        If True, randomly flip the signs of encoding vectors.
        Proposed to be activated during training for better generalization.
        Default : True.

    Example
    -------
    >>> import dgl
    >>> from dgl import SVDPE

    >>> transform = SVDPE(k=2, feat_name="svd_pe")
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g_ = transform(g)
    >>> # Row signs vary between runs because random_flip defaults to True.
    >>> print(g_.ndata['svd_pe'])
    tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],
            [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],
            [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],
            [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],
            [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  4.7287e-01]])
    """

    def __init__(self, k, feat_name="svd_pe", padding=False, random_flip=True):
        # Plain configuration; nothing is computed until the transform is
        # applied to a graph.
        self.k = k
        self.feat_name = feat_name
        self.padding = padding
        self.random_flip = random_flip

    def __call__(self, g):
        """Compute the encoding for ``g`` and store it in ``g.ndata``.

        The input graph is modified in place (a node feature named
        ``feat_name`` is written) and the same graph object is returned.
        """
        encoding = functional.svd_pe(
            g, k=self.k, padding=self.padding, random_flip=self.random_flip
        )
        # Move the encoding onto the graph's device before attaching it.
        g.ndata[self.feat_name] = F.copy_to(encoding, g.device)

        return g
1]) + ('user', 'follows', 'user'): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1]) }, idtype=idtype) lg = dgl.line_graph(g) assert lg.number_of_nodes() == 5 @@ -66,7 +69,7 @@ def test_line_graph2(idtype): assert np.array_equal(F.asnumpy(col), np.array([4, 0, 3, 1])) g = dgl.heterograph({ - ('user', 'follows', 'user'): ([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]) + ('user', 'follows', 'user'): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1]) }, idtype=idtype).formats('csr') lg = dgl.line_graph(g) assert lg.number_of_nodes() == 5 @@ -78,7 +81,7 @@ def test_line_graph2(idtype): np.array([3, 4, 0, 3, 4, 0, 1, 2])) g = dgl.heterograph({ - ('user', 'follows', 'user'): ([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]) + ('user', 'follows', 'user'): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1]) }, idtype=idtype).formats('csc') lg = dgl.line_graph(g) assert lg.number_of_nodes() == 5 @@ -93,6 +96,7 @@ def test_line_graph2(idtype): assert np.array_equal(col[order], np.array([3, 4, 0, 3, 4, 0, 1, 2])) + def test_no_backtracking(): N = 5 G = dgl.DGLGraph(nx.star_graph(N)) @@ -104,6 +108,7 @@ def test_no_backtracking(): assert not L.has_edges_between(e1, e2) assert not L.has_edges_between(e2, e1) + # reverse graph related @parametrize_idtype def test_reverse(idtype): @@ -168,14 +173,16 @@ def test_reverse(idtype): # test heterogeneous graph g = dgl.heterograph({ - ('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]), - ('user', 'plays', 'game'): ([0, 0, 2, 3, 3, 4, 1], [1, 0, 1, 0, 1, 0, 0]), + ('user', 'follows', 'user'): ( + [0, 1, 2, 4, 3, 1, 3], [1, 2, 3, 2, 0, 0, 1]), + ('user', 'plays', 'game'): ( + [0, 0, 2, 3, 3, 4, 1], [1, 0, 1, 0, 1, 0, 0]), ('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])}, idtype=idtype, device=F.ctx()) g.nodes['user'].data['h'] = F.tensor([0, 1, 2, 3, 4]) g.nodes['user'].data['hh'] = F.tensor([1, 1, 1, 1, 1]) g.nodes['game'].data['h'] = F.tensor([0, 1]) - g.edges['follows'].data['h'] = F.tensor([0, 1, 2, 4, 3 ,1, 3]) + g.edges['follows'].data['h'] = F.tensor([0, 1, 2, 4, 3, 
1, 3]) g.edges['follows'].data['hh'] = F.tensor([1, 2, 3, 2, 0, 0, 1]) g_r = dgl.reverse(g) @@ -187,21 +194,27 @@ def test_reverse(idtype): for ntype in g.ntypes: assert g.number_of_nodes(ntype) == g_r.number_of_nodes(ntype) assert F.array_equal(g.nodes['user'].data['h'], g_r.nodes['user'].data['h']) - assert F.array_equal(g.nodes['user'].data['hh'], g_r.nodes['user'].data['hh']) + assert F.array_equal(g.nodes['user'].data['hh'], + g_r.nodes['user'].data['hh']) assert F.array_equal(g.nodes['game'].data['h'], g_r.nodes['game'].data['h']) assert len(g_r.edges['follows'].data) == 0 - u_g, v_g, eids_g = g.all_edges(form='all', etype=('user', 'follows', 'user')) - u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('user', 'follows', 'user')) + u_g, v_g, eids_g = g.all_edges(form='all', + etype=('user', 'follows', 'user')) + u_rg, v_rg, eids_rg = g_r.all_edges(form='all', + etype=('user', 'follows', 'user')) assert F.array_equal(u_g, v_rg) assert F.array_equal(v_g, u_rg) assert F.array_equal(eids_g, eids_rg) u_g, v_g, eids_g = g.all_edges(form='all', etype=('user', 'plays', 'game')) - u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('game', 'plays', 'user')) + u_rg, v_rg, eids_rg = g_r.all_edges(form='all', + etype=('game', 'plays', 'user')) assert F.array_equal(u_g, v_rg) assert F.array_equal(v_g, u_rg) assert F.array_equal(eids_g, eids_rg) - u_g, v_g, eids_g = g.all_edges(form='all', etype=('developer', 'develops', 'game')) - u_rg, v_rg, eids_rg = g_r.all_edges(form='all', etype=('game', 'develops', 'developer')) + u_g, v_g, eids_g = g.all_edges(form='all', + etype=('developer', 'develops', 'game')) + u_rg, v_rg, eids_rg = g_r.all_edges(form='all', + etype=('game', 'develops', 'developer')) assert F.array_equal(u_g, v_rg) assert F.array_equal(v_g, u_rg) assert F.array_equal(eids_g, eids_rg) @@ -225,8 +238,10 @@ def test_reverse(idtype): assert etype_g[1] == etype_gr[1] assert etype_g[2] == etype_gr[0] assert g.number_of_edges(etype_g) == 
g_r.number_of_edges(etype_gr) - assert F.array_equal(g.edges['follows'].data['h'], g_r.edges['follows'].data['h']) - assert F.array_equal(g.edges['follows'].data['hh'], g_r.edges['follows'].data['hh']) + assert F.array_equal(g.edges['follows'].data['h'], + g_r.edges['follows'].data['h']) + assert F.array_equal(g.edges['follows'].data['hh'], + g_r.edges['follows'].data['hh']) # add new node feature to g_r g_r.nodes['user'].data['hhh'] = F.tensor([0, 1, 2, 3, 4]) @@ -254,6 +269,7 @@ def test_reverse_shared_frames(idtype): assert F.allclose(g.edges[[0, 2], [1, 1]].data['h'], rg.edges[[1, 1], [0, 2]].data['h']) + @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") def test_to_bidirected(): # homogeneous graph @@ -271,7 +287,7 @@ def test_to_bidirected(): # heterogeneous graph elist1 = [(0, 0), (0, 1), (1, 0), - (1, 1), (2, 1), (2, 2)] + (1, 1), (2, 1), (2, 2)] elist2 = [(0, 0), (0, 1)] g = dgl.heterograph({ ('user', 'wins', 'user'): tuple(zip(*elist1)), @@ -295,6 +311,7 @@ def test_to_bidirected(): big = dgl.to_bidirected(g, copy_ndata=True) assert F.array_equal(g.nodes['user'].data['h'], big.nodes['user'].data['h']) + def test_add_reverse_edges(): # homogeneous graph g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2]))) @@ -306,7 +323,8 @@ def test_add_reverse_edges(): assert F.array_equal(F.cat([u, v], dim=0), ub) assert F.array_equal(F.cat([v, u], dim=0), vb) assert F.array_equal(g.ndata['h'], bg.ndata['h']) - assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), bg.edata['h']) + assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), + bg.edata['h']) bg.ndata['hh'] = F.tensor([[0.], [1.], [2.], [1.]]) assert ('hh' in g.ndata) is False bg.edata['hh'] = F.tensor([[0.], [1.], [2.], [1.], [0.], [1.], [2.], [1.]]) @@ -322,26 +340,32 @@ def test_add_reverse_edges(): # zero edge graph g = dgl.graph(([], [])) - bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, exclude_self=False) + bg = 
dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, + exclude_self=False) # heterogeneous graph g = dgl.heterograph({ - ('user', 'wins', 'user'): (F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])), + ('user', 'wins', 'user'): ( + F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])), ('user', 'plays', 'game'): (F.tensor([1, 2, 1]), F.tensor([2, 1, 1])), ('user', 'follows', 'user'): (F.tensor([1, 2, 1]), F.tensor([0, 0, 0])) }) g.nodes['game'].data['hv'] = F.ones((3, 1)) g.nodes['user'].data['hv'] = F.ones((3, 1)) g.edges['wins'].data['h'] = F.tensor([0, 1, 2, 3, 4]) - bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, ignore_bipartite=True) - assert F.array_equal(g.nodes['game'].data['hv'], bg.nodes['game'].data['hv']) - assert F.array_equal(g.nodes['user'].data['hv'], bg.nodes['user'].data['hv']) + bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, + ignore_bipartite=True) + assert F.array_equal(g.nodes['game'].data['hv'], + bg.nodes['game'].data['hv']) + assert F.array_equal(g.nodes['user'].data['hv'], + bg.nodes['user'].data['hv']) u, v = g.all_edges(order='eid', etype=('user', 'wins', 'user')) ub, vb = bg.all_edges(order='eid', etype=('user', 'wins', 'user')) assert F.array_equal(F.cat([u, v], dim=0), ub) assert F.array_equal(F.cat([v, u], dim=0), vb) - assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0), - bg.edges['wins'].data['h']) + assert F.array_equal( + F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0), + bg.edges['wins'].data['h']) u, v = g.all_edges(order='eid', etype=('user', 'follows', 'user')) ub, vb = bg.all_edges(order='eid', etype=('user', 'follows', 'user')) assert F.array_equal(F.cat([u, v], dim=0), ub) @@ -354,7 +378,8 @@ def test_add_reverse_edges(): assert set(bg.edges['follows'].data.keys()) == {dgl.EID} # donot share ndata and edata - bg = dgl.add_reverse_edges(g, copy_ndata=False, copy_edata=False, ignore_bipartite=True) + bg = 
dgl.add_reverse_edges(g, copy_ndata=False, copy_edata=False, + ignore_bipartite=True) assert len(bg.edges['wins'].data) == 0 assert len(bg.edges['plays'].data) == 0 assert len(bg.edges['follows'].data) == 0 @@ -381,13 +406,16 @@ def test_add_reverse_edges(): bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True) assert g.number_of_nodes() == bg.number_of_nodes() assert F.array_equal(g.ndata['h'], bg.ndata['h']) - assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), bg.edata['h']) + assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), + bg.edata['h']) # heterogeneous graph g = dgl.heterograph({ - ('user', 'wins', 'user'): (F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])), + ('user', 'wins', 'user'): ( + F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])), ('user', 'plays', 'game'): (F.tensor([1, 2, 1]), F.tensor([2, 1, 1])), - ('user', 'follows', 'user'): (F.tensor([1, 2, 1]), F.tensor([0, 0, 0]))}, + ('user', 'follows', 'user'): ( + F.tensor([1, 2, 1]), F.tensor([0, 0, 0]))}, num_nodes_dict={ 'user': 5, 'game': 3 @@ -395,13 +423,17 @@ def test_add_reverse_edges(): g.nodes['game'].data['hv'] = F.ones((3, 1)) g.nodes['user'].data['hv'] = F.ones((5, 1)) g.edges['wins'].data['h'] = F.tensor([0, 1, 2, 3, 4]) - bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, ignore_bipartite=True) + bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, + ignore_bipartite=True) assert g.number_of_nodes('user') == bg.number_of_nodes('user') assert g.number_of_nodes('game') == bg.number_of_nodes('game') - assert F.array_equal(g.nodes['game'].data['hv'], bg.nodes['game'].data['hv']) - assert F.array_equal(g.nodes['user'].data['hv'], bg.nodes['user'].data['hv']) - assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0), - bg.edges['wins'].data['h']) + assert F.array_equal(g.nodes['game'].data['hv'], + bg.nodes['game'].data['hv']) + assert F.array_equal(g.nodes['user'].data['hv'], + 
bg.nodes['user'].data['hv']) + assert F.array_equal( + F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0), + bg.edges['wins'].data['h']) # test exclude_self g = dgl.heterograph({ @@ -414,6 +446,7 @@ def test_add_reverse_edges(): assert rg.num_edges('r2') == 4 assert F.array_equal(rg.edges['r1'].data['h'], F.tensor([0, 1, 2, 3, 1, 3])) + @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") def test_simple_graph(): elist = [(0, 1), (0, 2), (1, 2), (0, 1)] @@ -426,11 +459,12 @@ def test_simple_graph(): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == set(elist) + @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") def _test_bidirected_graph(): def _test(in_readonly, out_readonly): elist = [(0, 0), (0, 1), (1, 0), - (1, 1), (2, 1), (2, 2)] + (1, 1), (2, 1), (2, 2)] num_edges = 7 g = dgl.DGLGraph(elist, readonly=in_readonly) elist.append((1, 2)) @@ -473,6 +507,7 @@ def _test(g): g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3, directed=True)) _test(g) + @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") def test_khop_adj(): N = 20 @@ -511,6 +546,7 @@ def test_laplacian_lambda_max(): assert l_max < 2 + eps ''' + def create_large_graph(num_nodes, idtype=F.int64): row = np.random.choice(num_nodes, num_nodes * 10) col = np.random.choice(num_nodes, num_nodes * 10) @@ -519,22 +555,28 @@ def create_large_graph(num_nodes, idtype=F.int64): return dgl.from_scipy(spm, idtype=idtype) + # Disabled since everything will be on heterogeneous graphs @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") def test_partition_with_halo(): g = create_large_graph(1000) node_part = np.random.choice(4, g.number_of_nodes()) - subgs, _, _ = dgl.transforms.partition_graph_with_halo(g, node_part, 2, reshuffle=True) + subgs, _, _ = dgl.transforms.partition_graph_with_halo(g, node_part, 2, + reshuffle=True) for part_id, subg in subgs.items(): node_ids = 
np.nonzero(node_part == part_id)[0] lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0] orig_nids = F.asnumpy(subg.ndata['orig_id'])[lnode_ids] assert np.all(np.sort(orig_nids) == node_ids) - assert np.all(F.asnumpy(subg.in_degrees(lnode_ids)) == F.asnumpy(g.in_degrees(orig_nids))) - assert np.all(F.asnumpy(subg.out_degrees(lnode_ids)) == F.asnumpy(g.out_degrees(orig_nids))) + assert np.all(F.asnumpy(subg.in_degrees(lnode_ids)) == F.asnumpy( + g.in_degrees(orig_nids))) + assert np.all(F.asnumpy(subg.out_degrees(lnode_ids)) == F.asnumpy( + g.out_degrees(orig_nids))) + @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') -@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU") +@unittest.skipIf(F._default_context_str == 'gpu', + reason="METIS doesn't support GPU") @parametrize_idtype def test_metis_partition(idtype): # TODO(zhengda) Metis fails to partition a small graph. @@ -552,11 +594,13 @@ def test_metis_partition(idtype): assert_fail = True assert assert_fail + def check_metis_partition_with_constraint(g): ntypes = np.zeros((g.number_of_nodes(),), dtype=np.int32) - ntypes[0:int(g.number_of_nodes()/4)] = 1 - ntypes[int(g.number_of_nodes()*3/4):] = 2 - subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=1, balance_ntypes=ntypes) + ntypes[0:int(g.number_of_nodes() / 4)] = 1 + ntypes[int(g.number_of_nodes() * 3 / 4):] = 2 + subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=1, + balance_ntypes=ntypes) if subgs is not None: for i in subgs: subg = subgs[i] @@ -566,7 +610,8 @@ def check_metis_partition_with_constraint(g): print('type1:', np.sum(sub_ntypes == 1)) print('type2:', np.sum(sub_ntypes == 2)) subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=1, - balance_ntypes=ntypes, balance_edges=True) + balance_ntypes=ntypes, + balance_edges=True) if subgs is not None: for i in subgs: subg = subgs[i] @@ -576,6 +621,7 @@ def check_metis_partition_with_constraint(g): 
print('type1:', np.sum(sub_ntypes == 1)) print('type2:', np.sum(sub_ntypes == 2)) + def check_metis_partition(g, extra_hops): subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=extra_hops) num_inner_nodes = 0 @@ -586,7 +632,8 @@ def check_metis_partition(g, extra_hops): ledge_ids = np.nonzero(F.asnumpy(subg.edata['inner_edge']))[0] num_inner_nodes += len(lnode_ids) num_inner_edges += len(ledge_ids) - assert np.sum(F.asnumpy(subg.ndata['part_id']) == part_id) == len(lnode_ids) + assert np.sum(F.asnumpy(subg.ndata['part_id']) == part_id) == len( + lnode_ids) assert num_inner_nodes == g.number_of_nodes() print(g.number_of_edges() - num_inner_edges) @@ -594,7 +641,8 @@ def check_metis_partition(g, extra_hops): return # partitions with node reshuffling - subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=extra_hops, reshuffle=True) + subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=extra_hops, + reshuffle=True) num_inner_nodes = 0 num_inner_edges = 0 edge_cnts = np.zeros((g.number_of_edges(),)) @@ -604,13 +652,15 @@ def check_metis_partition(g, extra_hops): ledge_ids = np.nonzero(F.asnumpy(subg.edata['inner_edge']))[0] num_inner_nodes += len(lnode_ids) num_inner_edges += len(ledge_ids) - assert np.sum(F.asnumpy(subg.ndata['part_id']) == part_id) == len(lnode_ids) + assert np.sum(F.asnumpy(subg.ndata['part_id']) == part_id) == len( + lnode_ids) nids = F.asnumpy(subg.ndata[dgl.NID]) # ensure the local node Ids are contiguous. parent_ids = F.asnumpy(subg.ndata[dgl.NID]) parent_ids = parent_ids[:len(lnode_ids)] - assert np.all(parent_ids == np.arange(parent_ids[0], parent_ids[-1] + 1)) + assert np.all( + parent_ids == np.arange(parent_ids[0], parent_ids[-1] + 1)) # count the local edges. parent_ids = F.asnumpy(subg.edata[dgl.EID])[ledge_ids] @@ -625,14 +675,17 @@ def check_metis_partition(g, extra_hops): old_neighs2 = g.predecessors(old_nid) # If this is an inner node, it should have the full neighborhood. 
if inner_node[nid]: - assert np.all(np.sort(F.asnumpy(old_neighs1)) == np.sort(F.asnumpy(old_neighs2))) + assert np.all(np.sort(F.asnumpy(old_neighs1)) == np.sort( + F.asnumpy(old_neighs2))) # Normally, local edges are only counted once. assert np.all(edge_cnts == 1) assert num_inner_nodes == g.number_of_nodes() print(g.number_of_edges() - num_inner_edges) -@unittest.skipIf(F._default_context_str == 'gpu', reason="It doesn't support GPU") + +@unittest.skipIf(F._default_context_str == 'gpu', + reason="It doesn't support GPU") def test_reorder_nodes(): g = create_large_graph(1000) new_nids = np.random.permutation(g.number_of_nodes()) @@ -667,6 +720,7 @@ def test_reorder_nodes(): old_neighs2 = g.predecessors(old_nid) assert np.all(np.sort(old_neighs1) == np.sort(F.asnumpy(old_neighs2))) + @parametrize_idtype def test_compact(idtype): g1 = dgl.heterograph({ @@ -704,7 +758,8 @@ def _check(g, new_g, induced_nodes): # Test default new_g1 = dgl.compact_graphs(g1) - induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes} + induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in + new_g1.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert new_g1.idtype == idtype assert set(induced_nodes['user']) == set([1, 3, 5, 2, 7]) @@ -715,7 +770,8 @@ def _check(g, new_g, induced_nodes): new_g1 = dgl.compact_graphs( g1, always_preserve={'game': F.tensor([4, 7], idtype)}) assert new_g1.idtype == idtype - induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes} + induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in + new_g1.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert set(induced_nodes['user']) == set([1, 3, 5, 2, 7]) assert set(induced_nodes['game']) == set([4, 5, 6, 7]) @@ -724,7 +780,8 @@ def _check(g, new_g, induced_nodes): # Test with always_preserve given a tensor new_g3 = dgl.compact_graphs( g3, always_preserve=F.tensor([1, 
7], idtype)) - induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes} + induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in + new_g3.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert new_g3.idtype == idtype @@ -733,7 +790,8 @@ def _check(g, new_g, induced_nodes): # Test multiple graphs new_g1, new_g2 = dgl.compact_graphs([g1, g2]) - induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes} + induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in + new_g1.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert new_g1.idtype == idtype assert new_g2.idtype == idtype @@ -745,7 +803,8 @@ def _check(g, new_g, induced_nodes): # Test multiple graphs with always_preserve given a dict new_g1, new_g2 = dgl.compact_graphs( [g1, g2], always_preserve={'game': F.tensor([4, 7], dtype=idtype)}) - induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes} + induced_nodes = {ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in + new_g1.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert new_g1.idtype == idtype assert new_g2.idtype == idtype @@ -757,7 +816,8 @@ def _check(g, new_g, induced_nodes): # Test multiple graphs with always_preserve given a tensor new_g3, new_g4 = dgl.compact_graphs( [g3, g4], always_preserve=F.tensor([1, 7], dtype=idtype)) - induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes} + induced_nodes = {ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in + new_g3.ntypes} induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()} assert new_g3.idtype == idtype @@ -767,7 +827,9 @@ def _check(g, new_g, induced_nodes): _check(g3, new_g3, induced_nodes) _check(g4, new_g4, induced_nodes) -@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU to simple not implemented") + +@unittest.skipIf(F._default_context_str 
== 'gpu', + reason="GPU to simple not implemented") @parametrize_idtype def test_to_simple(idtype): # homogeneous graph @@ -814,12 +876,14 @@ def test_to_simple(idtype): g = dgl.heterograph({ ('user', 'follow', 'user'): ([0, 1, 2, 1, 1, 1], [1, 3, 2, 3, 4, 4]), - ('user', 'plays', 'game'): ([3, 2, 1, 1, 3, 2, 2], [5, 3, 4, 4, 5, 3, 3])}, + ('user', 'plays', 'game'): ( + [3, 2, 1, 1, 3, 2, 2], [5, 3, 4, 4, 5, 3, 3])}, idtype=idtype, device=F.ctx()) g.nodes['user'].data['h'] = F.tensor([0, 1, 2, 3, 4]) g.nodes['user'].data['hh'] = F.tensor([0, 1, 2, 3, 4]) g.edges['follow'].data['h'] = F.tensor([0, 1, 2, 3, 4, 5]) - sg, wb = dgl.to_simple(g, return_counts='weights', writeback_mapping=True, copy_edata=True) + sg, wb = dgl.to_simple(g, return_counts='weights', writeback_mapping=True, + copy_edata=True) g.nodes['game'].data['h'] = F.tensor([0, 1, 2, 3, 4, 5]) for etype in g.canonical_etypes: @@ -842,7 +906,8 @@ def test_to_simple(idtype): assert eid_map[i] == suv.index(e) # shared ndata assert F.array_equal(sg.nodes['user'].data['h'], g.nodes['user'].data['h']) - assert F.array_equal(sg.nodes['user'].data['hh'], g.nodes['user'].data['hh']) + assert F.array_equal(sg.nodes['user'].data['hh'], + g.nodes['user'].data['hh']) assert 'h' not in sg.nodes['game'].data # new ndata to sg sg.nodes['user'].data['hhh'] = F.tensor([0, 1, 2, 3, 4]) @@ -869,6 +934,7 @@ def test_to_simple(idtype): sg = dgl.to_simple(g) assert F.array_equal(sg.edge_ids(u, v), eids) + @parametrize_idtype def test_to_block(idtype): def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True): @@ -906,7 +972,8 @@ def checkall(g, bg, dst_nodes, include_dst_in_src=True): check(g, bg, ntype, etype, None, include_dst_in_src) # homogeneous graph - g = dgl.graph((F.tensor([1, 2], dtype=idtype), F.tensor([2, 3], dtype=idtype))) + g = dgl.graph( + (F.tensor([1, 2], dtype=idtype), F.tensor([2, 3], dtype=idtype))) dst_nodes = F.tensor([3, 2], dtype=idtype) bg = dgl.to_block(g, dst_nodes=dst_nodes) check(g, bg, 
'_N', '_E', dst_nodes) @@ -932,17 +999,20 @@ def check_features(g, bg): for key in g.nodes[ntype].data: assert F.array_equal( bg.srcnodes[ntype].data[key], - F.gather_row(g.nodes[ntype].data[key], bg.srcnodes[ntype].data[dgl.NID])) + F.gather_row(g.nodes[ntype].data[key], + bg.srcnodes[ntype].data[dgl.NID])) for ntype in bg.dsttypes: for key in g.nodes[ntype].data: assert F.array_equal( bg.dstnodes[ntype].data[key], - F.gather_row(g.nodes[ntype].data[key], bg.dstnodes[ntype].data[dgl.NID])) + F.gather_row(g.nodes[ntype].data[key], + bg.dstnodes[ntype].data[dgl.NID])) for etype in bg.canonical_etypes: for key in g.edges[etype].data: assert F.array_equal( bg.edges[etype].data[key], - F.gather_row(g.edges[etype].data[key], bg.edges[etype].data[dgl.EID])) + F.gather_row(g.edges[etype].data[key], + bg.edges[etype].data[dgl.EID])) bg = dgl.to_block(g_a) check(g_a, bg, 'A', 'AA', None) @@ -966,7 +1036,8 @@ def check_features(g, bg): bg = dgl.to_block(g_ab) assert bg.idtype == idtype assert bg.number_of_nodes('SRC/B') == 4 - assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID]) + assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], + bg.dstnodes['B'].data[dgl.NID]) assert bg.number_of_nodes('DST/A') == 0 checkall(g_ab, bg, None) check_features(g_ab, bg) @@ -974,12 +1045,14 @@ def check_features(g, bg): dst_nodes = {'B': F.tensor([5, 6, 3, 1], dtype=idtype)} bg = dgl.to_block(g, dst_nodes) assert bg.number_of_nodes('SRC/B') == 4 - assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID]) + assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], + bg.dstnodes['B'].data[dgl.NID]) assert bg.number_of_nodes('DST/A') == 0 checkall(g, bg, dst_nodes) check_features(g, bg) - dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=idtype), 'B': F.tensor([3, 5, 6, 1], dtype=idtype)} + dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=idtype), + 'B': F.tensor([3, 5, 6, 1], dtype=idtype)} bg = dgl.to_block(g, dst_nodes=dst_nodes) checkall(g, 
bg, dst_nodes) check_features(g, bg) @@ -994,7 +1067,8 @@ def check_features(g, bg): check_features(g, bg) # test without include_dst_in_src - dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=idtype), 'B': F.tensor([3, 5, 6, 1], dtype=idtype)} + dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=idtype), + 'B': F.tensor([3, 5, 6, 1], dtype=idtype)} bg = dgl.to_block(g, dst_nodes=dst_nodes, include_dst_in_src=False) checkall(g, bg, dst_nodes, False) check_features(g, bg) @@ -1005,7 +1079,7 @@ def check_features(g, bg): # use the previous run to get the list of source nodes src_nodes[ntype] = bg.srcnodes[ntype].data[dgl.NID] bg = dgl.to_block(g, dst_nodes=dst_nodes, include_dst_in_src=False, - src_nodes=src_nodes) + src_nodes=src_nodes) checkall(g, bg, dst_nodes, False) check_features(g, bg) @@ -1035,12 +1109,14 @@ def check(g1, etype, g, edges_removed): for fmt in ['coo', 'csr', 'csc']: for edges_to_remove in [[2], [2, 2], [3, 2], [1, 3, 1, 2]]: - g = dgl.graph(([0, 2, 1, 3], [1, 3, 2, 4]), idtype=idtype).formats(fmt) + g = dgl.graph(([0, 2, 1, 3], [1, 3, 2, 4]), idtype=idtype).formats( + fmt) g1 = dgl.remove_edges(g, F.tensor(edges_to_remove, idtype)) check(g1, None, g, edges_to_remove) g = dgl.from_scipy( - spsp.csr_matrix(([1, 1, 1, 1], ([0, 2, 1, 3], [1, 3, 2, 4])), shape=(5, 5)), + spsp.csr_matrix(([1, 1, 1, 1], ([0, 2, 1, 3], [1, 3, 2, 4])), + shape=(5, 5)), idtype=idtype).formats(fmt) g1 = dgl.remove_edges(g, F.tensor(edges_to_remove, idtype)) check(g1, None, g, edges_to_remove) @@ -1049,12 +1125,16 @@ def check(g1, etype, g, edges_removed): ('A', 'AA', 'A'): ([0, 2, 1, 3], [1, 3, 2, 4]), ('A', 'AB', 'B'): ([0, 1, 3, 1], [1, 3, 5, 6]), ('B', 'BA', 'A'): ([2, 3], [3, 2])}, idtype=idtype) - g2 = dgl.remove_edges(g, {'AA': F.tensor([2], idtype), 'AB': F.tensor([3], idtype), 'BA': F.tensor([1], idtype)}) + g2 = dgl.remove_edges(g, {'AA': F.tensor([2], idtype), + 'AB': F.tensor([3], idtype), + 'BA': F.tensor([1], idtype)}) check(g2, 'AA', g, [2]) check(g2, 'AB', g, [3]) 
check(g2, 'BA', g, [1]) - g3 = dgl.remove_edges(g, {'AA': F.tensor([], idtype), 'AB': F.tensor([3], idtype), 'BA': F.tensor([1], idtype)}) + g3 = dgl.remove_edges(g, {'AA': F.tensor([], idtype), + 'AB': F.tensor([3], idtype), + 'BA': F.tensor([1], idtype)}) check(g3, 'AA', g, []) check(g3, 'AB', g, [3]) check(g3, 'BA', g, [1]) @@ -1064,6 +1144,7 @@ def check(g1, etype, g, edges_removed): check(g4, 'AB', g, [3, 1, 2, 0]) check(g4, 'BA', g, []) + @parametrize_idtype def test_add_edges(idtype): # homogeneous graph @@ -1116,8 +1197,8 @@ def test_add_edges(idtype): g.edata['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx()) u = F.tensor([0, 1], dtype=idtype) v = F.tensor([2, 3], dtype=idtype) - e_feat = {'h' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), - 'hh' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} + e_feat = {'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), + 'hh': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} g = dgl.add_edges(g, u, v, e_feat) assert g.number_of_nodes() == 4 assert g.number_of_edges() == 4 @@ -1132,8 +1213,8 @@ def test_add_edges(idtype): g = dgl.graph(([], []), num_nodes=0, idtype=idtype, device=F.ctx()) u = F.tensor([0, 1], dtype=idtype) v = F.tensor([2, 2], dtype=idtype) - e_feat = {'h' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), - 'hh' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} + e_feat = {'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), + 'hh': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} g = dgl.add_edges(g, u, v, e_feat) assert g.number_of_nodes() == 3 assert g.number_of_edges() == 2 @@ -1145,7 +1226,8 @@ def test_add_edges(idtype): # bipartite graph g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) u = 0 v = 1 g = dgl.add_edges(g, u, v) @@ -1173,7 +1255,8 @@ def test_add_edges(idtype): # node id 
larger than current max node id g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) u = F.tensor([0, 2], dtype=idtype) v = F.tensor([2, 3], dtype=idtype) g = dgl.add_edges(g, u, v) @@ -1187,14 +1270,17 @@ def test_add_edges(idtype): # has data g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) - g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx()) - g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) + g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), + ctx=F.ctx()) + g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 2], dtype=idtype), + ctx=F.ctx()) g.edata['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx()) u = F.tensor([0, 2], dtype=idtype) v = F.tensor([2, 3], dtype=idtype) - e_feat = {'h' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), - 'hh' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} + e_feat = {'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()), + 'hh': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} g = dgl.add_edges(g, u, v, e_feat) assert g.number_of_nodes('user') == 3 assert g.number_of_nodes('game') == 4 @@ -1202,8 +1288,10 @@ def test_add_edges(idtype): u, v = g.edges(form='uv', order='eid') assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype)) - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1, 0], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2, 2, 0], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1, 0], dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2, 2, 0], dtype=idtype)) 
assert F.array_equal(g.edata['h'], F.tensor([1, 1, 2, 2], dtype=idtype)) assert F.array_equal(g.edata['hh'], F.tensor([0, 0, 2, 2], dtype=idtype)) @@ -1220,15 +1308,19 @@ def test_add_edges(idtype): u, v = g.edges(form='uv', order='eid', etype='plays') assert F.array_equal(u, F.tensor([0, 1, 1, 2, 0, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 0, 1, 1, 2, 3], dtype=idtype)) - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1, 1], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2, 0, 0], dtype=idtype)) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1, 1, 1, 1, 0, 0], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1, 1], dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2, 0, 0], dtype=idtype)) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([1, 1, 1, 1, 0, 0], dtype=idtype)) # add with feature e_feat = {'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())} u = F.tensor([0, 2], dtype=idtype) v = F.tensor([2, 3], dtype=idtype) - g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 1, 1], dtype=idtype), ctx=F.ctx()) + g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 1, 1], dtype=idtype), + ctx=F.ctx()) g = dgl.add_edges(g, u, v, data=e_feat, etype='develops') assert g.number_of_nodes('user') == 3 assert g.number_of_nodes('game') == 4 @@ -1238,15 +1330,19 @@ def test_add_edges(idtype): u, v = g.edges(form='uv', order='eid', etype='develops') assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1, 2, 3], dtype=idtype)) - assert F.array_equal(g.nodes['developer'].data['h'], F.tensor([3, 3, 0], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2, 1, 1], dtype=idtype)) - assert F.array_equal(g.edges['develops'].data['h'], F.tensor([0, 0, 2, 2], dtype=idtype)) + assert F.array_equal(g.nodes['developer'].data['h'], + F.tensor([3, 3, 0], 
dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2, 1, 1], dtype=idtype)) + assert F.array_equal(g.edges['develops'].data['h'], + F.tensor([0, 0, 2, 2], dtype=idtype)) + @parametrize_idtype def test_add_nodes(idtype): # homogeneous Graphs g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx()) - g.ndata['h'] = F.copy_to(F.tensor([1,1,1], dtype=idtype), ctx=F.ctx()) + g.ndata['h'] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()) new_g = dgl.add_nodes(g, 1) assert g.number_of_nodes() == 3 assert new_g.number_of_nodes() == 4 @@ -1254,18 +1350,23 @@ def test_add_nodes(idtype): # zero node graph g = dgl.graph(([], []), num_nodes=3, idtype=idtype, device=F.ctx()) - g.ndata['h'] = F.copy_to(F.tensor([1,1,1], dtype=idtype), ctx=F.ctx()) - g = dgl.add_nodes(g, 1, data={'h' : F.copy_to(F.tensor([2], dtype=idtype), ctx=F.ctx())}) + g.ndata['h'] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()) + g = dgl.add_nodes(g, 1, data={ + 'h': F.copy_to(F.tensor([2], dtype=idtype), ctx=F.ctx())}) assert g.number_of_nodes() == 4 assert F.array_equal(g.ndata['h'], F.tensor([1, 1, 1, 2], dtype=idtype)) # bipartite graph g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) - g = dgl.add_nodes(g, 2, data={'h' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}, ntype='user') + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) + g = dgl.add_nodes(g, 2, data={ + 'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}, + ntype='user') assert g.number_of_nodes('user') == 4 assert g.number_of_nodes('game') == 3 - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([0, 0, 2, 2], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([0, 0, 2, 2], dtype=idtype)) g = dgl.add_nodes(g, 2, ntype='game') assert g.number_of_nodes('user') == 4 assert g.number_of_nodes('game') == 5 @@ -1273,12 +1374,17 @@ def test_add_nodes(idtype): # 
heterogeneous graph g = create_test_heterograph3(idtype) g = dgl.add_nodes(g, 1, ntype='user') - g = dgl.add_nodes(g, 2, data={'h' : F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}, ntype='game') + g = dgl.add_nodes(g, 2, data={ + 'h': F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}, + ntype='game') assert g.number_of_nodes('user') == 4 assert g.number_of_nodes('game') == 4 assert g.number_of_nodes('developer') == 2 - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1, 1, 0], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2, 2, 2], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1, 1, 0], dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2, 2, 2], dtype=idtype)) + @parametrize_idtype def test_remove_edges(idtype): @@ -1325,7 +1431,8 @@ def test_remove_edges(idtype): # bipartite graph g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) e = 0 g = dgl.remove_edges(g, e) assert g.number_of_edges() == 1 @@ -1333,7 +1440,8 @@ def test_remove_edges(idtype): assert F.array_equal(u, F.tensor([1], dtype=idtype)) assert F.array_equal(v, F.tensor([2], dtype=idtype)) g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) e = [0] g = dgl.remove_edges(g, e) assert g.number_of_edges() == 1 @@ -1346,31 +1454,41 @@ def test_remove_edges(idtype): # has data g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) - g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx()) - g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) + 
g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1], dtype=idtype), + ctx=F.ctx()) + g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2, 2], dtype=idtype), + ctx=F.ctx()) g.edata['h'] = F.copy_to(F.tensor([1, 2], dtype=idtype), ctx=F.ctx()) g = dgl.remove_edges(g, 1) assert g.number_of_edges() == 1 - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2, 2], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1], dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2, 2], dtype=idtype)) assert F.array_equal(g.edata['h'], F.tensor([1], dtype=idtype)) # heterogeneous graph g = create_test_heterograph3(idtype) - g.edges['plays'].data['h'] = F.copy_to(F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()) + g.edges['plays'].data['h'] = F.copy_to(F.tensor([1, 2, 3, 4], dtype=idtype), + ctx=F.ctx()) g = dgl.remove_edges(g, 1, etype='plays') assert g.number_of_edges('plays') == 3 u, v = g.edges(form='uv', order='eid', etype='plays') assert F.array_equal(u, F.tensor([0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1, 1], dtype=idtype)) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1, 3, 4], dtype=idtype)) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([1, 3, 4], dtype=idtype)) # remove all edges of 'develops' g = dgl.remove_edges(g, [0, 1], etype='develops') assert g.number_of_edges('develops') == 0 - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1, 1], dtype=idtype)) - assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2], dtype=idtype)) - assert F.array_equal(g.nodes['developer'].data['h'], F.tensor([3, 3], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1, 1], dtype=idtype)) + assert F.array_equal(g.nodes['game'].data['h'], + F.tensor([2, 2], dtype=idtype)) + assert F.array_equal(g.nodes['developer'].data['h'], + F.tensor([3, 
3], dtype=idtype)) # batched graph ctx = F.ctx() @@ -1381,17 +1499,20 @@ def test_remove_edges(idtype): bg_r = dgl.remove_edges(bg, 2) assert bg.batch_size == bg_r.batch_size assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([2, 0, 2], dtype=F.int64)) bg_r = dgl.remove_edges(bg, [0, 2]) assert bg.batch_size == bg_r.batch_size assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([1, 0, 2], dtype=F.int64)) bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype)) assert bg.batch_size == bg_r.batch_size assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([1, 0, 2], dtype=F.int64)) # batched heterogeneous graph g1 = dgl.heterograph({ @@ -1412,43 +1533,57 @@ def test_remove_edges(idtype): ntypes = bg.ntypes for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([1, 2, 0], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges('plays'), bg.batch_num_edges('plays')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([1, 2, 0], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges('plays'), + bg.batch_num_edges('plays')) bg_r = dgl.remove_edges(bg, 2, etype='plays') assert bg.batch_size == bg_r.batch_size for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([2, 0, 1], dtype=F.int64)) + assert 
F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([2, 0, 1], dtype=F.int64)) bg_r = dgl.remove_edges(bg, [0, 1, 3], etype='follows') assert bg.batch_size == bg_r.batch_size for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64)) - assert F.array_equal(bg.batch_num_edges('plays'), bg_r.batch_num_edges('plays')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([0, 1, 0], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('plays'), + bg_r.batch_num_edges('plays')) bg_r = dgl.remove_edges(bg, [1, 2], etype='plays') assert bg.batch_size == bg_r.batch_size for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], dtype=F.int64)) - bg_r = dgl.remove_edges(bg, F.tensor([0, 1, 3], dtype=idtype), etype='follows') + bg_r = dgl.remove_edges(bg, F.tensor([0, 1, 3], dtype=idtype), + etype='follows') assert bg.batch_size == bg_r.batch_size for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64)) - assert F.array_equal(bg.batch_num_edges('plays'), bg_r.batch_num_edges('plays')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([0, 1, 0], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('plays'), + bg_r.batch_num_edges('plays')) bg_r = dgl.remove_edges(bg, F.tensor([1, 2], dtype=idtype), etype='plays') assert 
bg.batch_size == bg_r.batch_size for nty in ntypes: assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], dtype=F.int64)) + @parametrize_idtype def test_remove_nodes(idtype): @@ -1498,7 +1633,8 @@ def test_remove_nodes(idtype): # node id larger than current max node id g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) n = 0 g = dgl.remove_nodes(g, n, ntype='user') assert g.number_of_nodes('user') == 1 @@ -1508,7 +1644,8 @@ def test_remove_nodes(idtype): assert F.array_equal(u, F.tensor([0], dtype=idtype)) assert F.array_equal(v, F.tensor([2], dtype=idtype)) g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) n = [1] g = dgl.remove_nodes(g, n, ntype='user') assert g.number_of_nodes('user') == 1 @@ -1518,7 +1655,8 @@ def test_remove_nodes(idtype): assert F.array_equal(u, F.tensor([0], dtype=idtype)) assert F.array_equal(v, F.tensor([1], dtype=idtype)) g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1], [1, 2])}, idtype=idtype, + device=F.ctx()) n = F.tensor([0], dtype=idtype) g = dgl.remove_nodes(g, n, ntype='game') assert g.number_of_nodes('user') == 2 @@ -1526,24 +1664,28 @@ def test_remove_nodes(idtype): assert g.number_of_edges() == 2 u, v = g.edges(form='uv', order='eid') assert F.array_equal(u, F.tensor([0, 1], dtype=idtype)) - assert F.array_equal(v, 
F.tensor([0 ,1], dtype=idtype)) + assert F.array_equal(v, F.tensor([0, 1], dtype=idtype)) # heterogeneous graph g = create_test_heterograph3(idtype) - g.edges['plays'].data['h'] = F.copy_to(F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()) + g.edges['plays'].data['h'] = F.copy_to(F.tensor([1, 2, 3, 4], dtype=idtype), + ctx=F.ctx()) g = dgl.remove_nodes(g, 0, ntype='game') assert g.number_of_nodes('user') == 3 assert g.number_of_nodes('game') == 1 assert g.number_of_nodes('developer') == 2 assert g.number_of_edges('plays') == 2 assert g.number_of_edges('develops') == 1 - assert F.array_equal(g.nodes['user'].data['h'], F.tensor([1, 1, 1], dtype=idtype)) + assert F.array_equal(g.nodes['user'].data['h'], + F.tensor([1, 1, 1], dtype=idtype)) assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2], dtype=idtype)) - assert F.array_equal(g.nodes['developer'].data['h'], F.tensor([3, 3], dtype=idtype)) + assert F.array_equal(g.nodes['developer'].data['h'], + F.tensor([3, 3], dtype=idtype)) u, v = g.edges(form='uv', order='eid', etype='plays') assert F.array_equal(u, F.tensor([1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 0], dtype=idtype)) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([3, 4], dtype=idtype)) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([3, 4], dtype=idtype)) u, v = g.edges(form='uv', order='eid', etype='develops') assert F.array_equal(u, F.tensor([1], dtype=idtype)) assert F.array_equal(v, F.tensor([0], dtype=idtype)) @@ -1556,18 +1698,24 @@ def test_remove_nodes(idtype): bg = dgl.batch([g1, g2, g3]) bg_r = dgl.remove_nodes(bg, 1) assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_nodes(), + F.tensor([4, 0, 5], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([0, 0, 3], dtype=F.int64)) bg_r = 
dgl.remove_nodes(bg, [1, 7]) assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_nodes(), + F.tensor([4, 0, 4], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([0, 0, 1], dtype=F.int64)) bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype)) assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_nodes(), + F.tensor([4, 0, 4], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges(), + F.tensor([0, 0, 1], dtype=F.int64)) # batched heterogeneous graph g1 = dgl.heterograph({ @@ -1585,45 +1733,72 @@ def test_remove_nodes(idtype): bg = dgl.batch([g1, g2, g3]) bg_r = dgl.remove_nodes(bg, 1, ntype='user') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 6, 3], dtype=F.int64)) - assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game')) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 2, 0], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 2], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_nodes('user'), + F.tensor([3, 6, 3], dtype=F.int64)) + assert F.array_equal(bg.batch_num_nodes('game'), + bg_r.batch_num_nodes('game')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([0, 2, 0], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 2], dtype=F.int64)) bg_r = dgl.remove_nodes(bg, 6, ntype='game') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user')) - assert F.array_equal(bg_r.batch_num_nodes('game'), 
F.tensor([3, 2, 2], dtype=F.int64)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([2, 0, 1], dtype=F.int64)) + assert F.array_equal(bg.batch_num_nodes('user'), + bg_r.batch_num_nodes('user')) + assert F.array_equal(bg_r.batch_num_nodes('game'), + F.tensor([3, 2, 2], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([2, 0, 1], dtype=F.int64)) bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype='user') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 4, 2], dtype=F.int64)) - assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game')) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_nodes('user'), + F.tensor([3, 4, 2], dtype=F.int64)) + assert F.array_equal(bg.batch_num_nodes('game'), + bg_r.batch_num_nodes('game')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([0, 1, 0], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], dtype=F.int64)) bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype='game') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user')) - assert F.array_equal(bg_r.batch_num_nodes('game'), F.tensor([2, 0, 2], dtype=F.int64)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) - - bg_r = dgl.remove_nodes(bg, F.tensor([1, 5, 6, 11], dtype=idtype), ntype='user') + assert F.array_equal(bg.batch_num_nodes('user'), + bg_r.batch_num_nodes('user')) + assert 
F.array_equal(bg_r.batch_num_nodes('game'), + F.tensor([2, 0, 2], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], dtype=F.int64)) + + bg_r = dgl.remove_nodes(bg, F.tensor([1, 5, 6, 11], dtype=idtype), + ntype='user') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 4, 2], dtype=F.int64)) - assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game')) - assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64)) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) - - bg_r = dgl.remove_nodes(bg, F.tensor([0, 3, 4, 7], dtype=idtype), ntype='game') + assert F.array_equal(bg_r.batch_num_nodes('user'), + F.tensor([3, 4, 2], dtype=F.int64)) + assert F.array_equal(bg.batch_num_nodes('game'), + bg_r.batch_num_nodes('game')) + assert F.array_equal(bg_r.batch_num_edges('follows'), + F.tensor([0, 1, 0], dtype=F.int64)) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], dtype=F.int64)) + + bg_r = dgl.remove_nodes(bg, F.tensor([0, 3, 4, 7], dtype=idtype), + ntype='game') assert bg_r.batch_size == bg.batch_size - assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user')) - assert F.array_equal(bg_r.batch_num_nodes('game'), F.tensor([2, 0, 2], dtype=F.int64)) - assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows')) - assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64)) + assert F.array_equal(bg.batch_num_nodes('user'), + bg_r.batch_num_nodes('user')) + assert F.array_equal(bg_r.batch_num_nodes('game'), + F.tensor([2, 0, 2], dtype=F.int64)) + assert F.array_equal(bg.batch_num_edges('follows'), + bg_r.batch_num_edges('follows')) + assert F.array_equal(bg_r.batch_num_edges('plays'), + F.tensor([1, 0, 1], 
dtype=F.int64)) + @parametrize_idtype def test_add_selfloop(idtype): @@ -1632,7 +1807,8 @@ def test_add_selfloop(idtype): # test for fill_data is float g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx()) g.edata['he'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx()) - g.edata['he1'] = F.copy_to(F.tensor([[0., 1.], [2., 3.], [4., 5.]]), ctx=F.ctx()) + g.edata['he1'] = F.copy_to(F.tensor([[0., 1.], [2., 3.], [4., 5.]]), + ctx=F.ctx()) g.ndata['hn'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx()) g = dgl.add_self_loop(g) assert g.number_of_nodes() == 3 @@ -1640,14 +1816,17 @@ def test_add_selfloop(idtype): u, v = g.edges(form='uv', order='eid') assert F.array_equal(u, F.tensor([0, 0, 2, 0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype)) - assert F.array_equal(g.edata['he'], F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)) + assert F.array_equal(g.edata['he'], + F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)) assert F.array_equal(g.edata['he1'], F.tensor([[0., 1.], [2., 3.], [4., 5.], - [1., 1.], [1., 1.], [1., 1.]])) + [1., 1.], [1., 1.], + [1., 1.]])) # test for fill_data is int g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx()) g.edata['he'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx()) - g.edata['he1'] = F.copy_to(F.tensor([[0, 1], [2, 3], [4, 5]], dtype=idtype), ctx=F.ctx()) + g.edata['he1'] = F.copy_to(F.tensor([[0, 1], [2, 3], [4, 5]], dtype=idtype), + ctx=F.ctx()) g.ndata['hn'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx()) g = dgl.add_self_loop(g, fill_data=1) assert g.number_of_nodes() == 3 @@ -1655,14 +1834,17 @@ def test_add_selfloop(idtype): u, v = g.edges(form='uv', order='eid') assert F.array_equal(u, F.tensor([0, 0, 2, 0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype)) - assert F.array_equal(g.edata['he'], F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)) + assert F.array_equal(g.edata['he'], + 
F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)) assert F.array_equal(g.edata['he1'], F.tensor([[0, 1], [2, 3], [4, 5], - [1, 1], [1, 1], [1, 1]], dtype=idtype)) + [1, 1], [1, 1], [1, 1]], + dtype=idtype)) # test for fill_data is str g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx()) g.edata['he'] = F.copy_to(F.tensor([1., 2., 3.]), ctx=F.ctx()) - g.edata['he1'] = F.copy_to(F.tensor([[0., 1.], [2., 3.], [4., 5.]]), ctx=F.ctx()) + g.edata['he1'] = F.copy_to(F.tensor([[0., 1.], [2., 3.], [4., 5.]]), + ctx=F.ctx()) g.ndata['hn'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx()) g = dgl.add_self_loop(g, fill_data='sum') assert g.number_of_nodes() == 3 @@ -1672,11 +1854,13 @@ def test_add_selfloop(idtype): assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype)) assert F.array_equal(g.edata['he'], F.tensor([1., 2., 3., 3., 2., 1.])) assert F.array_equal(g.edata['he1'], F.tensor([[0., 1.], [2., 3.], [4., 5.], - [4., 5.], [2., 3.], [0., 1.]])) + [4., 5.], [2., 3.], + [0., 1.]])) # bipartite graph g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1, 2], [1, 2, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1, 2], [1, 2, 2])}, idtype=idtype, + device=F.ctx()) # nothing will happend raise_error = False try: @@ -1687,7 +1871,8 @@ def test_add_selfloop(idtype): # test for fill_data is float g = create_test_heterograph5(idtype) - g.edges['follows'].data['h1'] = F.copy_to(F.tensor([[0., 1.], [1., 2.]]), ctx=F.ctx()) + g.edges['follows'].data['h1'] = F.copy_to(F.tensor([[0., 1.], [1., 2.]]), + ctx=F.ctx()) g = dgl.add_self_loop(g, etype='follows') assert g.number_of_nodes('user') == 3 assert g.number_of_nodes('game') == 2 @@ -1696,14 +1881,18 @@ def test_add_selfloop(idtype): u, v = g.edges(form='uv', order='eid', etype='follows') assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h'], 
F.tensor([1, 2, 1, 1, 1], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h1'], F.tensor([[0., 1.], [1., 2.], [1., 1.], - [1., 1.], [1., 1.]])) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1, 2], dtype=idtype)) + assert F.array_equal(g.edges['follows'].data['h'], + F.tensor([1, 2, 1, 1, 1], dtype=idtype)) + assert F.array_equal(g.edges['follows'].data['h1'], + F.tensor([[0., 1.], [1., 2.], [1., 1.], + [1., 1.], [1., 1.]])) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([1, 2], dtype=idtype)) # test for fill_data is int g = create_test_heterograph5(idtype) - g.edges['follows'].data['h1'] = F.copy_to(F.tensor([[0, 1], [1, 2]], dtype=idtype), ctx=F.ctx()) + g.edges['follows'].data['h1'] = F.copy_to( + F.tensor([[0, 1], [1, 2]], dtype=idtype), ctx=F.ctx()) g = dgl.add_self_loop(g, fill_data=1, etype='follows') assert g.number_of_nodes('user') == 3 assert g.number_of_nodes('game') == 2 @@ -1712,10 +1901,13 @@ def test_add_selfloop(idtype): u, v = g.edges(form='uv', order='eid', etype='follows') assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h'], F.tensor([1, 2, 1, 1, 1], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h1'], F.tensor([[0, 1], [1, 2], [1, 1], - [1, 1], [1, 1]], dtype=idtype)) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1, 2], dtype=idtype)) + assert F.array_equal(g.edges['follows'].data['h'], + F.tensor([1, 2, 1, 1, 1], dtype=idtype)) + assert F.array_equal(g.edges['follows'].data['h1'], + F.tensor([[0, 1], [1, 2], [1, 1], + [1, 1], [1, 1]], dtype=idtype)) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([1, 2], dtype=idtype)) # test for fill_data is str g = dgl.heterograph({ @@ -1724,10 +1916,13 @@ def test_add_selfloop(idtype): ('user', 'plays', 'game'): (F.tensor([0, 1], dtype=idtype), F.tensor([0, 1], dtype=idtype))}, idtype=idtype, 
device=F.ctx()) - g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()) - g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()) + g.nodes['user'].data['h'] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), + ctx=F.ctx()) + g.nodes['game'].data['h'] = F.copy_to(F.tensor([2, 2], dtype=idtype), + ctx=F.ctx()) g.edges['follows'].data['h'] = F.copy_to(F.tensor([1., 2.]), ctx=F.ctx()) - g.edges['follows'].data['h1'] = F.copy_to(F.tensor([[0., 1.], [1., 2.]]), ctx=F.ctx()) + g.edges['follows'].data['h1'] = F.copy_to(F.tensor([[0., 1.], [1., 2.]]), + ctx=F.ctx()) g.edges['plays'].data['h'] = F.copy_to(F.tensor([1., 2.]), ctx=F.ctx()) g = dgl.add_self_loop(g, fill_data='mean', etype='follows') assert g.number_of_nodes('user') == 3 @@ -1737,9 +1932,11 @@ def test_add_selfloop(idtype): u, v = g.edges(form='uv', order='eid', etype='follows') assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h'], F.tensor([1., 2., 1., 2., 0.])) - assert F.array_equal(g.edges['follows'].data['h1'], F.tensor([[0., 1.], [1., 2.], [0., 1.], - [1., 2.], [0., 0.]])) + assert F.array_equal(g.edges['follows'].data['h'], + F.tensor([1., 2., 1., 2., 0.])) + assert F.array_equal(g.edges['follows'].data['h1'], + F.tensor([[0., 1.], [1., 2.], [0., 1.], + [1., 2.], [0., 0.]])) assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1., 2.])) raise_error = False @@ -1749,6 +1946,7 @@ def test_add_selfloop(idtype): raise_error = True assert raise_error + @parametrize_idtype def test_remove_selfloop(idtype): # homogeneous graph @@ -1761,7 +1959,8 @@ def test_remove_selfloop(idtype): # bipartite graph g = dgl.heterograph( - {('user', 'plays', 'game'): ([0, 1, 2], [1, 2, 2])}, idtype=idtype, device=F.ctx()) + {('user', 'plays', 'game'): ([0, 1, 2], [1, 2, 2])}, idtype=idtype, + device=F.ctx()) # nothing will happend raise_error 
= False try: @@ -1779,8 +1978,10 @@ def test_remove_selfloop(idtype): u, v = g.edges(form='uv', order='eid', etype='follows') assert F.array_equal(u, F.tensor([1, 2], dtype=idtype)) assert F.array_equal(v, F.tensor([0, 1], dtype=idtype)) - assert F.array_equal(g.edges['follows'].data['h'], F.tensor([2, 4], dtype=idtype)) - assert F.array_equal(g.edges['plays'].data['h'], F.tensor([1, 2], dtype=idtype)) + assert F.array_equal(g.edges['follows'].data['h'], + F.tensor([2, 4], dtype=idtype)) + assert F.array_equal(g.edges['plays'].data['h'], + F.tensor([1, 2], dtype=idtype)) raise_error = False try: @@ -1790,7 +1991,8 @@ def test_remove_selfloop(idtype): assert raise_error # batch information - g = dgl.graph(([0, 0, 0, 1, 3, 3, 4], [1, 0, 0, 2, 3, 4, 4]), idtype=idtype, device=F.ctx()) + g = dgl.graph(([0, 0, 0, 1, 3, 3, 4], [1, 0, 0, 2, 3, 4, 4]), idtype=idtype, + device=F.ctx()) g.set_batch_num_nodes(F.tensor([3, 2], dtype=F.int64)) g.set_batch_num_edges(F.tensor([4, 3], dtype=F.int64)) g = dgl.remove_self_loop(g) @@ -1864,7 +2066,8 @@ def test_reorder_graph(idtype): raise_error = False try: dgl.reorder_graph(mg, - node_permute_algo='metis', permute_config={'k': 2}) + node_permute_algo='metis', + permute_config={'k': 2}) except: raise_error = True assert not raise_error @@ -1883,7 +2086,7 @@ def test_reorder_graph(idtype): raise_error = False try: dgl.reorder_graph(g, node_permute_algo='custom', permute_config={ - 'nodes_perm': nodes_perm[:g.num_nodes() - 1]}) + 'nodes_perm': nodes_perm[:g.num_nodes() - 1]}) except: raise_error = True assert raise_error @@ -1908,12 +2111,14 @@ def test_reorder_graph(idtype): # TODO: shall we fix them? 
# add 'csc' format if needed - #fg = g.formats('csr') - #assert 'csc' not in sum(fg.formats().values(), []) - #rfg = dgl.reorder_graph(fg) - #assert 'csc' in sum(rfg.formats().values(), []) + # fg = g.formats('csr') + # assert 'csc' not in sum(fg.formats().values(), []) + # rfg = dgl.reorder_graph(fg) + # assert 'csc' in sum(rfg.formats().values(), []) -@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support a slicing operation") + +@unittest.skipIf(dgl.backend.backend_name == "tensorflow", + reason="TF doesn't support a slicing operation") @parametrize_idtype def test_norm_by_dst(idtype): # Case1: A homogeneous graph @@ -1929,6 +2134,7 @@ def test_norm_by_dst(idtype): eweight = dgl.norm_by_dst(g, etype=('user', 'plays', 'game')) assert F.allclose(eweight, F.tensor([0.5, 0.5, 1.0])) + @parametrize_idtype def test_module_add_self_loop(idtype): g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx()) @@ -2007,6 +2213,7 @@ def test_module_add_self_loop(idtype): assert 'w1' in new_g.edges['plays'].data assert 'w2' in new_g.edges['follows'].data + @parametrize_idtype def test_module_remove_self_loop(idtype): transform = dgl.RemoveSelfLoop() @@ -2050,6 +2257,7 @@ def test_module_remove_self_loop(idtype): assert 'w1' in new_g.edges['plays'].data assert 'w2' in new_g.edges['follows'].data + @parametrize_idtype def test_module_add_reverse(idtype): transform = dgl.AddReverse() @@ -2067,7 +2275,8 @@ def test_module_add_reverse(idtype): assert eset == {(0, 1), (1, 0)} assert F.allclose(g.ndata['h'], new_g.ndata['h']) assert F.allclose(g.edata['w'], F.narrow_row(new_g.edata['w'], 0, 1)) - assert F.allclose(F.narrow_row(new_g.edata['w'], 1, 2), F.zeros((1, 2), F.float32, F.ctx())) + assert F.allclose(F.narrow_row(new_g.edata['w'], 1, 2), + F.zeros((1, 2), F.float32, F.ctx())) # Case2: Add reverse edges for a homogeneous graph and copy edata transform = dgl.AddReverse(copy_edata=True) @@ -2092,7 +2301,8 @@ def test_module_add_reverse(idtype): 
assert new_g.idtype == g.idtype assert g.ntypes == new_g.ntypes assert set(new_g.canonical_etypes) == { - ('user', 'plays', 'game'), ('user', 'follows', 'user'), ('game', 'rev_plays', 'user')} + ('user', 'plays', 'game'), ('user', 'follows', 'user'), + ('game', 'rev_plays', 'user')} for nty in g.ntypes: assert g.num_nodes(nty) == new_g.num_nodes(nty) @@ -2136,7 +2346,9 @@ def test_module_add_reverse(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(2, 1), (2, 2)} -@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not supported for to_simple") + +@unittest.skipIf(F._default_context_str == 'gpu', + reason="GPU not supported for to_simple") @parametrize_idtype def test_module_to_simple(idtype): transform = dgl.ToSimple() @@ -2176,6 +2388,7 @@ def test_module_to_simple(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 1), (1, 1)} + @parametrize_idtype def test_module_line_graph(idtype): transform = dgl.LineGraph() @@ -2199,6 +2412,7 @@ def test_module_line_graph(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 2)} + @parametrize_idtype def test_module_khop_graph(idtype): transform = dgl.KHopGraph(2) @@ -2213,6 +2427,7 @@ def test_module_khop_graph(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 2)} + @parametrize_idtype def test_module_add_metapaths(idtype): g = dgl.heterograph({ @@ -2225,8 +2440,10 @@ def test_module_add_metapaths(idtype): # Case1: keep_orig_edges is True metapaths = { - 'accepted': [('person', 'author', 'paper'), ('paper', 'accepted', 'venue')], - 'rejected': [('person', 'author', 'paper'), ('paper', 'rejected', 'venue')] + 'accepted': [('person', 'author', 'paper'), + ('paper', 'accepted', 'venue')], + 'rejected': [('person', 'author', 'paper'), + ('paper', 'rejected', 'venue')] } transform = dgl.AddMetaPaths(metapaths) new_g = transform(g) @@ -2242,8 +2459,10 @@ def 
test_module_add_metapaths(idtype): assert new_g.num_nodes(nty) == g.num_nodes(nty) for ety in g.canonical_etypes: assert new_g.num_edges(ety) == g.num_edges(ety) - assert F.allclose(g.nodes['venue'].data['h'], new_g.nodes['venue'].data['h']) - assert F.allclose(g.edges['author'].data['h'], new_g.edges['author'].data['h']) + assert F.allclose(g.nodes['venue'].data['h'], + new_g.nodes['venue'].data['h']) + assert F.allclose(g.edges['author'].data['h'], + new_g.edges['author'].data['h']) src, dst = new_g.edges(etype=('person', 'accepted', 'venue')) eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) @@ -2262,7 +2481,8 @@ def test_module_add_metapaths(idtype): assert len(new_g.canonical_etypes) == 2 for nty in new_g.ntypes: assert new_g.num_nodes(nty) == g.num_nodes(nty) - assert F.allclose(g.nodes['venue'].data['h'], new_g.nodes['venue'].data['h']) + assert F.allclose(g.nodes['venue'].data['h'], + new_g.nodes['venue'].data['h']) src, dst = new_g.edges(etype=('person', 'accepted', 'venue')) eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) @@ -2272,6 +2492,7 @@ def test_module_add_metapaths(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 1), (1, 1)} + @parametrize_idtype def test_module_compose(idtype): g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx()) @@ -2285,6 +2506,7 @@ def test_module_compose(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 1), (1, 2), (1, 0), (2, 1), (0, 0), (1, 1), (2, 2)} + @parametrize_idtype def test_module_gcnnorm(idtype): g = dgl.heterograph({ @@ -2297,13 +2519,17 @@ def test_module_gcnnorm(idtype): new_g = transform(g) assert 'w' not in new_g.edges[('A', 'r2', 'B')].data assert F.allclose(new_g.edges[('A', 'r1', 'A')].data['w'], - F.tensor([1./2, 1./math.sqrt(2), 0.])) - assert F.allclose(new_g.edges[('B', 'r3', 'B')].data['w'], F.tensor([1./3, 2./3, 0.])) + F.tensor([1. / 2, 1. 
/ math.sqrt(2), 0.])) + assert F.allclose(new_g.edges[('B', 'r3', 'B')].data['w'], + F.tensor([1. / 3, 2. / 3, 0.])) -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_ppr(idtype): - g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()) + g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, + device=F.ctx()) g.ndata['h'] = F.randn((6, 2)) transform = dgl.PPR(avg_degree=2) new_g = transform(g) @@ -2313,7 +2539,8 @@ def test_module_ppr(idtype): src, dst = new_g.edges() eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2), - (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)} + (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), + (5, 5)} assert F.allclose(g.ndata['h'], new_g.ndata['h']) assert 'w' in new_g.edata @@ -2325,11 +2552,14 @@ def test_module_ppr(idtype): assert eset == {(0, 0), (1, 1), (1, 3), (2, 2), (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)} -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_heat_kernel(idtype): # Case1: directed graph - g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()) + g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, + device=F.ctx()) g.ndata['h'] = F.randn((6, 2)) transform = dgl.HeatKernel(avg_degree=1) new_g = transform(g) @@ -2347,11 +2577,14 @@ def test_module_heat_kernel(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)} -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for 
now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_gdc(idtype): transform = dgl.GDC([0.1, 0.2, 0.1], avg_degree=1) - g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()) + g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, + device=F.ctx()) g.ndata['h'] = F.randn((6, 2)) new_g = transform(g) assert new_g.idtype == g.idtype @@ -2359,7 +2592,8 @@ def test_module_gdc(idtype): assert new_g.num_nodes() == g.num_nodes() src, dst = new_g.edges() eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) - assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2), (2, 3), + assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2), + (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)} assert F.allclose(g.ndata['h'], new_g.ndata['h']) assert 'w' in new_g.edata @@ -2371,7 +2605,9 @@ def test_module_gdc(idtype): eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)} -@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support a slicing operation") + +@unittest.skipIf(dgl.backend.backend_name == "tensorflow", + reason="TF doesn't support a slicing operation") @parametrize_idtype def test_module_node_shuffle(idtype): transform = dgl.NodeShuffle() @@ -2384,7 +2620,9 @@ def test_module_node_shuffle(idtype): new_nfeat = g.nodes['B'].data['h'] assert F.allclose(old_nfeat, new_nfeat) -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_drop_node(idtype): transform = dgl.DropNode() @@ -2401,7 +2639,9 @@ def test_module_drop_node(idtype): # Ensure that the original graph is not corrupted assert num_nodes_old == num_nodes_new 
-@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_drop_edge(idtype): transform = dgl.DropEdge() @@ -2419,6 +2659,7 @@ def test_module_drop_edge(idtype): # Ensure that the original graph is not corrupted assert num_edges_old == num_edges_new + @parametrize_idtype def test_module_add_edge(idtype): transform = dgl.AddEdge() @@ -2438,33 +2679,37 @@ def test_module_add_edge(idtype): # Ensure that the original graph is not corrupted assert num_edges_old == num_edges_new + @parametrize_idtype def test_module_random_walk_pe(idtype): transform = dgl.RandomWalkPE(2, 'rwpe') g = dgl.graph(([0, 1, 1], [1, 1, 0]), idtype=idtype, device=F.ctx()) new_g = transform(g) - tgt = F.copy_to(F.tensor([[0., 0.5],[0.5, 0.75]]), g.device) + tgt = F.copy_to(F.tensor([[0., 0.5], [0.5, 0.75]]), g.device) assert F.allclose(new_g.ndata['rwpe'], tgt) + @parametrize_idtype def test_module_laplacian_pe(idtype): - g = dgl.graph(([2, 1, 0, 3, 1, 1],[3, 1, 1, 2, 1, 0]), idtype=idtype, device=F.ctx()) - tgt_eigval = F.copy_to(F.repeat(F.tensor([[1.1534e-17, 1.3333e+00, 2., np.nan, np.nan]]), - g.num_nodes(), dim=0), g.device) + g = dgl.graph(([2, 1, 0, 3, 1, 1], [3, 1, 1, 2, 1, 0]), idtype=idtype, + device=F.ctx()) + tgt_eigval = F.copy_to( + F.repeat(F.tensor([[1.1534e-17, 1.3333e+00, 2., np.nan, np.nan]]), + g.num_nodes(), dim=0), g.device) tgt_pe = F.copy_to(F.tensor([[0.5, 0.86602539, 0., 0., 0.], - [0.86602539, 0.5, 0., 0., 0.], - [0., 0., 0.70710677, 0., 0.], - [0., 0., 0.70710677, 0., 0.]]), g.device) + [0.86602539, 0.5, 0., 0., 0.], + [0., 0., 0.70710677, 0., 0.], + [0., 0., 0.70710677, 0., 0.]]), g.device) # without padding (k=n) transform = dgl.LaplacianPE(5, feat_name='lappe', padding=True) @@ -2477,18 +2722,21 @@ def test_module_laplacian_pe(idtype): assert F.allclose(new_g.ndata['lappe'].abs(), tgt_pe) # with 
eigenvalues - transform = dgl.LaplacianPE(5, feat_name='lappe', eigval_name='eigval', padding=True) + transform = dgl.LaplacianPE(5, feat_name='lappe', eigval_name='eigval', + padding=True) new_g = transform(g) # tensorflow has no abs() api if dgl.backend.backend_name == 'tensorflow': - assert F.allclose(new_g.ndata['eigval'][:,:3], tgt_eigval[:,:3]) + assert F.allclose(new_g.ndata['eigval'][:, :3], tgt_eigval[:, :3]) assert F.allclose(new_g.ndata['lappe'].__abs__(), tgt_pe) # pytorch & mxnet else: - assert F.allclose(new_g.ndata['eigval'][:,:3], tgt_eigval[:,:3]) + assert F.allclose(new_g.ndata['eigval'][:, :3], tgt_eigval[:, :3]) assert F.allclose(new_g.ndata['lappe'].abs(), tgt_pe) -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'])) def test_module_sign(g): import torch @@ -2512,7 +2760,8 @@ def test_module_sign(g): target = torch.matmul(adj, g.ndata['h']) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) - transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw') + transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', + eweight_name='scalar_w', diffuse_op='raw') g = transform(g) target = torch.matmul(weight_adj, g.ndata['h']) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) @@ -2524,8 +2773,10 @@ def test_module_sign(g): target = torch.matmul(adj_rw, g.ndata['h']) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) - weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj) - transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw') + weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), + weight_adj) + transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', + eweight_name='scalar_w', diffuse_op='rw') g = 
transform(g) target = torch.matmul(weight_adj_rw, g.ndata['h']) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) @@ -2554,23 +2805,30 @@ def test_module_sign(g): # ppr alpha = 0.2 - transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha) + transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', + alpha=alpha) g = transform(g) - target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h'] + target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * \ + g.ndata['h'] assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) - transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', + transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', + eweight_name='scalar_w', diffuse_op='ppr', alpha=alpha) g = transform(g) - target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h'] + target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * \ + g.ndata['h'] assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol) -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') + +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_row_feat_normalizer(idtype): # Case1: Normalize features of a homogeneous graph. transform = dgl.RowFeatNormalizer(subtract_min=True, - node_feat_names=['h'], edge_feat_names=['w']) + node_feat_names=['h'], + edge_feat_names=['w']) g = dgl.rand_graph(5, 5, idtype=idtype, device=F.ctx()) g.ndata['h'] = F.randn((g.num_nodes(), 128)) g.edata['w'] = F.randn((g.num_edges(), 128)) @@ -2582,14 +2840,16 @@ def test_module_row_feat_normalizer(idtype): # Case2: Normalize features of a heterogeneous graph. 
transform = dgl.RowFeatNormalizer(subtract_min=True, - node_feat_names=['h', 'h2'], edge_feat_names=['w']) + node_feat_names=['h', 'h2'], + edge_feat_names=['w']) g = dgl.heterograph({ ('user', 'follows', 'user'): (F.tensor([1, 2]), F.tensor([3, 4])), ('player', 'plays', 'game'): (F.tensor([2, 2]), F.tensor([1, 1])) }, idtype=idtype, device=F.ctx()) g.ndata['h'] = {'game': F.randn((2, 128)), 'player': F.randn((3, 128))} g.ndata['h2'] = {'user': F.randn((5, 128))} - g.edata['w'] = {('user', 'follows', 'user'): F.randn((2, 128)), ('player', 'plays', 'game'): F.randn((2, 128))} + g.edata['w'] = {('user', 'follows', 'user'): F.randn((2, 128)), + ('player', 'plays', 'game'): F.randn((2, 128))} g = transform(g) assert g.ndata['h']['game'].shape == (2, 128) assert g.ndata['h']['player'].shape == (3, 128) @@ -2598,11 +2858,16 @@ def test_module_row_feat_normalizer(idtype): assert g.edata['w'][('player', 'plays', 'game')].shape == (2, 128) assert F.allclose(g.ndata['h']['game'].sum(1), F.tensor([1.0, 1.0])) assert F.allclose(g.ndata['h']['player'].sum(1), F.tensor([1.0, 1.0, 1.0])) - assert F.allclose(g.ndata['h2']['user'].sum(1), F.tensor([1.0, 1.0, 1.0, 1.0, 1.0])) - assert F.allclose(g.edata['w'][('user', 'follows', 'user')].sum(1), F.tensor([1.0, 1.0])) - assert F.allclose(g.edata['w'][('player', 'plays', 'game')].sum(1), F.tensor([1.0, 1.0])) + assert F.allclose(g.ndata['h2']['user'].sum(1), + F.tensor([1.0, 1.0, 1.0, 1.0, 1.0])) + assert F.allclose(g.edata['w'][('user', 'follows', 'user')].sum(1), + F.tensor([1.0, 1.0])) + assert F.allclose(g.edata['w'][('player', 'plays', 'game')].sum(1), + F.tensor([1.0, 1.0])) + -@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') +@unittest.skipIf(dgl.backend.backend_name != 'pytorch', + reason='Only support PyTorch for now') @parametrize_idtype def test_module_feat_mask(idtype): # Case1: Mask node and edge feature tensors of a homogeneous graph. 
@@ -2632,6 +2897,7 @@ def test_module_feat_mask(idtype): assert g.edata['w'][('user', 'follows', 'user')].shape == (2, 5) assert g.edata['w'][('player', 'plays', 'game')].shape == (2, 5) + @parametrize_idtype def test_shortest_dist(idtype): g = dgl.graph(([0, 1, 1, 2], [2, 0, 3, 3]), idtype=idtype, device=F.ctx()) @@ -2664,6 +2930,7 @@ def test_shortest_dist(idtype): assert F.array_equal(dist, tgt_dist) assert F.array_equal(paths, tgt_paths) + @parametrize_idtype def test_module_to_levi(idtype): transform = dgl.ToLevi() @@ -2692,6 +2959,43 @@ def test_module_to_levi(idtype): assert F.allclose(lg.nodes['node'].data['h'], g.ndata['h']) assert F.allclose(lg.nodes['edge'].data['w'], g.edata['w']) -if __name__ == '__main__': + +@parametrize_idtype +def test_module_svd_pe(idtype): + g = dgl.graph( + ( + [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4], + [2, 3, 0, 2, 0, 2, 3, 4, 3, 4, 0, 1] + ), + idtype=idtype, + device=F.ctx(), + ) + # without padding + tgt_pe = F.copy_to( + F.tensor( + [ + [0.6669, 0.3068, 0.7979, 0.8477], + [0.6311, 0.6101, 0.1248, 0.5137], + [1.1993, 0.0665, 0.9183, 0.1455], + [0.5682, 0.6766, 0.8952, 0.6449], + [0.3393, 0.8363, 0.6500, 0.4564], + ] + ), + g.device, + ) + transform_1 = dgl.SVDPE(k=2, feat_name="svd_pe") + g1 = transform_1(g) + if dgl.backend.backend_name == "tensorflow": + assert F.allclose(g1.ndata["svd_pe"].__abs__(), tgt_pe) + else: + assert F.allclose(g1.ndata["svd_pe"].abs(), tgt_pe) + + # with padding + transform_2 = dgl.SVDPE(k=6, feat_name="svd_pe", padding=True) + g2 = transform_2(g) + assert F.shape(g2.ndata["svd_pe"]) == (5, 12) + + +if __name__ == "__main__": test_partition_with_halo() test_module_heat_kernel(F.int32)