import argparse
import random
import numpy as np
import scipy.sparse as sp
import scipy.sparse.csgraph
import sklearn.linear_model as sklm
import sklearn.metrics as skm
import sklearn.model_selection as skms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
CHECKPOINT_PREFIX = "g2g"
def load_dataset(filename):
    """Load an attributed graph dataset from an .npz file

    Returns the adjacency matrix A, the attribute matrix X (both CSR), and the
    label vector z.
    """

    def load_sparse_matrix(file, name):
        return sp.csr_matrix(
            (file[f"{name}_data"], file[f"{name}_indices"], file[f"{name}_indptr"]),
            shape=file[f"{name}_shape"],
            dtype=np.float32,
        )

    with np.load(filename) as f:
        A = load_sparse_matrix(f, "adj")
        X = load_sparse_matrix(f, "attr")
        z = f["labels"].astype(np.float32)
        return A, X, z

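
# Illustrative sketch (not part of the original module): `load_dataset` expects an
# .npz archive holding the CSR components of the adjacency and attribute matrices
# plus a label vector. The helper below writes a tiny compatible file; the helper
# name, the toy matrices, and the "toy.npz" filename are made up for demonstration.
def _write_toy_dataset(filename="toy.npz"):
    A = sp.csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32))
    X = sp.csr_matrix(np.array([[1, 0], [0, 1]], dtype=np.float32))
    labels = np.array([0, 1])
    np.savez(
        filename,
        adj_data=A.data, adj_indices=A.indices, adj_indptr=A.indptr, adj_shape=A.shape,
        attr_data=X.data, attr_indices=X.indices, attr_indptr=X.indptr, attr_shape=X.shape,
        labels=labels,
    )
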
def level_sets(A, K):
    """Enumerate the level sets of each node's neighborhood

    Parameters
    ----------
    A : np.ndarray or sp.spmatrix
        Adjacency matrix
    K : int or None
        Maximum path length to consider. All nodes that are further apart go
        into the last level set.

    Returns
    -------
    dict
        { node: [i -> i-hop neighborhood] }
    """
    if A.shape[0] == 0 or A.shape[1] == 0:
        return {}

    # Compute the shortest path length between any two nodes
    D = scipy.sparse.csgraph.shortest_path(
        A, method="D", unweighted=True, directed=False
    )

    # Cast to int so that the distances can be used as indices.
    #
    # D has inf for any pair of nodes from different components and np.isfinite
    # is really slow on individual numbers, so we call it only once here.
    D[np.logical_not(np.isfinite(D))] = -1.0
    D = D.astype(np.int32)

    # Handle nodes farther apart than K as if they were unreachable
    if K is not None:
        D[D > K] = -1

    # Read the level sets off the distance matrix
    set_counts = D.max(axis=1)
    sets = {i: [[] for _ in range(1 + set_counts[i] + 1)] for i in range(D.shape[0])}
    for i in range(D.shape[0]):
        sets[i][0].append(i)
        for j in range(i):
            d = D[i, j]

            # If a node is unreachable, add it to the outermost level set. This
            # trick ensures that nodes from different connected components get
            # pushed apart and is essential to get good performance.
            if d < 0:
                sets[i][-1].append(j)
                sets[j][-1].append(i)
            else:
                sets[i][d].append(j)
                sets[j][d].append(i)

    return sets

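
# Illustrative check (not part of the original module): level sets on a toy graph.
# The graph below, a path 0-1-2 plus an isolated node 3, and the helper name are
# made up for demonstration.
def _demo_level_sets():
    A = sp.csr_matrix(
        np.array(
            [
                [0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0],
            ],
            dtype=np.float32,
        )
    )
    ls = level_sets(A, K=None)
    assert ls[0] == [[0], [1], [2], [3]]  # self, 1-hop, 2-hop, unreachable
    assert ls[3] == [[3], [0, 1, 2]]  # isolated node: every other node is "unreachable"
    return ls
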
class CompleteKPartiteGraph:
    """A complete k-partite graph"""

    def __init__(self, partitions):
        """
        Parameters
        ----------
        partitions : [[int]]
            List of node partitions where each partition is a list of node IDs
        """
        self.partitions = partitions
        self.counts = np.array([len(p) for p in partitions])
        self.total = self.counts.sum()

        assert len(self.partitions) >= 2
        assert np.all(self.counts > 0)

        # Enumerate all nodes so that we can easily look them up with an index
        # from 0..total-1
        self.nodes = np.array([node for partition in partitions for node in partition])

        # Precompute the size of each node's partition
        self.n_i = np.array(
            [n for partition, n in zip(self.partitions, self.counts) for _ in partition]
        )

        # Precompute the start of each node's partition in self.nodes
        self.start_i = np.array(
            [
                end - n
                for partition, n, end in zip(
                    self.partitions, self.counts, self.counts.cumsum()
                )
                for node in partition
            ]
        )

        # Each node has edges to every other node except the ones in its own
        # level set
        self.out_degrees = np.full(self.total, self.total) - self.n_i

        # Sample the first nodes proportionally to their out-degree
        self.p = self.out_degrees / self.out_degrees.sum()

    def sample_edges(self, size=1):
        """Sample edges (j, k) from this graph uniformly and independently

        Returns
        -------
        ([j], [k])
            j will always be in a lower partition than k
        """
        # Sample the originating nodes for each edge
        j = np.random.choice(self.total, size=size, p=self.p, replace=True)

        # For each j sample one outgoing edge uniformly.
        #
        # We want to sample from 0..total-1 excluding j's own partition
        # start_i[j]..(start_i[j] + n_i[j] - 1). We do this by sampling from
        # 0..out_degrees[j]-1 and shifting any sample that lands at or beyond
        # the start of j's partition past that partition.
        k = np.random.randint(self.out_degrees[j])
        mask = k >= self.start_i[j]
        k += mask.astype(np.int64) * self.n_i[j]

        # Swap nodes such that the partition index of j is less than that of k
        # for each edge
        wrong_order = k < j
        tmp = k[wrong_order]
        k[wrong_order] = j[wrong_order]
        j[wrong_order] = tmp

        # Translate node indices back into user-provided node IDs
        j = self.nodes[j]
        k = self.nodes[k]

        return j, k

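
# Illustrative sketch (not part of the original module): sampling from a small
# complete k-partite graph. The three single-node partitions and the helper name
# are made up for demonstration; they mirror the 1-hop, 2-hop, and "unreachable"
# level sets that `AttributedGraph` passes in below.
def _demo_sample_edges():
    g = CompleteKPartiteGraph([[1], [2], [3]])
    j, k = g.sample_edges(size=5)
    # Both arrays hold node IDs; for every sampled edge, j comes from a partition
    # with a lower index than k's, and the two endpoints never share a partition.
    assert j.shape == (5,) and k.shape == (5,)
    return j, k
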
class AttributedGraph:
    def __init__(self, A, X, z, K):
        self.A = A
        self.X = torch.tensor(X.toarray())
        self.z = z
        self.level_sets = level_sets(A, K)

        # Precompute the cardinality of each level set for every node
        self.level_counts = {
            node: np.array(list(map(len, sets))) for node, sets in self.level_sets.items()
        }

        # Precompute the weight of each node's expected value in the loss, i.e.
        # the number of (j, k) pairs with j and k in different level sets
        N = self.level_counts
        self.loss_weights = 0.5 * np.array(
            [N[i][1:].sum() ** 2 - (N[i][1:] ** 2).sum() for i in self.nodes()]
        )

        n = self.A.shape[0]
        self.neighborhoods = [None] * n
        for i in range(n):
            ls = self.level_sets[i]
            if len(ls) >= 3:
                self.neighborhoods[i] = CompleteKPartiteGraph(ls[1:])

    def nodes(self):
        return range(self.A.shape[0])

    def eligible_nodes(self):
        """Nodes that can be used to compute the loss"""
        N = self.level_counts

        # If a node only has first-degree neighbors, the loss is undefined
        return [i for i in self.nodes() if len(N[i]) >= 3]

    def sample_two_neighbors(self, node, size=1):
        """Sample two nodes of different rank from the neighborhood of `node`"""
        level_sets = self.level_sets[node]
        if len(level_sets) < 3:
            raise Exception(f"Node {node} has only one layer of neighbors")

        return self.neighborhoods[node].sample_edges(size)

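
# Illustrative sketch (not part of the original module): wiring the toy graph from
# `_demo_level_sets` into an `AttributedGraph`. The identity attributes, zero labels,
# and the helper name are made up for demonstration.
def _demo_attributed_graph():
    A = sp.csr_matrix(
        np.array(
            [
                [0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0],
            ],
            dtype=np.float32,
        )
    )
    X = sp.identity(4, dtype=np.float32, format="csr")
    z = np.zeros(4, dtype=np.float32)
    graph = AttributedGraph(A, X, z, K=None)
    # Node 0 has level-set sizes [1, 1, 1, 1]; ignoring the self level, that gives
    # 0.5 * (3**2 - 3) = 3 cross-level pairs, so its loss weight is 3.
    assert graph.loss_weights[0] == 3.0
    # Node 3 only has the "unreachable" level set, so it cannot be used for the loss.
    assert 3 not in graph.eligible_nodes()
    return graph
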
class GraphDataset(IterableDataset):
    """A dataset that generates all necessary information for one training step

    Sampling the edges is actually the most expensive part of the whole training
    loop, and by putting it into the dataset generator, we can parallelize it
    independently of the training loop.
    """

    def __init__(self, graph, nsamples, iterations):
        self.graph = graph
        self.nsamples = nsamples
        self.iterations = iterations

    def __iter__(self):
        graph = self.graph
        nsamples = self.nsamples
        eligible_nodes = list(graph.eligible_nodes())
        nrows = len(eligible_nodes) * nsamples
        weights = torch.empty(nrows)
        for _ in range(self.iterations):
            i_indices = torch.empty(nrows, dtype=torch.long)
            j_indices = torch.empty(nrows, dtype=torch.long)
            k_indices = torch.empty(nrows, dtype=torch.long)
            for index, i in enumerate(eligible_nodes):
                start = index * nsamples
                end = start + nsamples

                i_indices[start:end] = i
                js, ks = graph.sample_two_neighbors(i, size=nsamples)
                j_indices[start:end] = torch.tensor(js)
                k_indices[start:end] = torch.tensor(ks)
                weights[start:end] = graph.loss_weights[i]

            yield graph.X, i_indices, j_indices, k_indices, weights, nsamples
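

# Illustrative sketch (not part of the original module): consuming the dataset with a
# PyTorch DataLoader. `batch_size=None` disables automatic batching, so every yielded
# tuple passes through unchanged. With `num_workers > 0`, each worker would iterate the
# dataset independently unless the work is split via `get_worker_info` (imported above),
# so this sketch stays in the main process. The helper name and sample counts are made
# up for demonstration.
def _demo_dataloader(graph):
    dataset = GraphDataset(graph, nsamples=5, iterations=10)
    loader = DataLoader(dataset, batch_size=None, num_workers=0)
    for X, i_idx, j_idx, k_idx, weights, nsamples in loader:
        # Each item already covers all eligible nodes with `nsamples` (j, k) pairs
        # per node, so it corresponds to one full training step.
        return X.shape, i_idx.shape, weights.shape, nsamples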