-
Notifications
You must be signed in to change notification settings - Fork 3k
/
train.py
338 lines (306 loc) · 10 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
#!/usr/bin/env python
# coding: utf-8
import argparse
import time
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import ogb
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
class RGAT(nn.Module):
    """Relational GAT over a stack of message-flow graphs with an MLP head.

    Every layer owns one ``dglnn.GATConv`` per edge type, a linear skip
    projection and a ``BatchNorm1d``.  The per-edge-type attention outputs
    are summed onto the skip projection, then normalised, passed through
    ELU and dropout.  The output of the last layer is fed to a two-layer
    MLP that produces ``out_channels`` values per destination node.

    Parameters
    ----------
    in_channels : int
        Size of the input node features consumed by the first layer.
    out_channels : int
        Size of the vector produced by the MLP head.
    hidden_channels : int
        Width of every hidden representation.  Each attention head gets
        ``hidden_channels // num_heads`` channels and the concatenated
        heads are reshaped back to ``hidden_channels``, so this should be
        divisible by ``num_heads``.
    num_etypes : int
        Number of edge types; type ``j`` is selected with
        ``edata["etype"] == j``.
    num_layers : int
        Number of GAT layers.  At least one layer is always created.
    num_heads : int
        Number of attention heads per ``GATConv``.
    dropout : float
        Dropout probability used after every layer and inside the MLP.
    pred_ntype
        Stored as ``self.pred_ntype``; not read anywhere in this class.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        num_etypes,
        num_layers,
        num_heads,
        dropout,
        pred_ntype,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()

        # The first layer reads raw features, every later layer reads the
        # hidden representation.  Modules are created in the order
        # convs -> norm -> skip per layer so that parameter initialisation
        # consumes the RNG exactly as before.
        layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)
        for layer_in in layer_inputs:
            self.convs.append(
                nn.ModuleList(
                    [
                        dglnn.GATConv(
                            layer_in,
                            hidden_channels // num_heads,
                            num_heads,
                            allow_zero_in_degree=True,
                        )
                        for _ in range(num_etypes)
                    ]
                )
            )
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(layer_in, hidden_channels))

        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )
        self.dropout = nn.Dropout(dropout)

        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
        """Run one GAT layer per message-flow graph and apply the MLP head.

        Parameters
        ----------
        mfgs : sequence
            Message-flow graphs, one per layer, each carrying an
            ``"etype"`` edge feature.  Must not be longer than the number
            of layers built in ``__init__`` (indexing would fail).
        x : torch.Tensor
            Features of the source nodes of ``mfgs[0]``.

        Returns
        -------
        torch.Tensor
            Output of ``self.mlp`` for the destination nodes of the last
            message-flow graph.
        """
        for i, mfg in enumerate(mfgs):
            # The first ``num_dst_nodes`` rows of ``x`` are used as the
            # destination-node features.
            x_dst = x[: mfg.num_dst_nodes()]
            graph = dgl.block_to_graph(mfg)

            out = self.skips[i](x_dst)
            for etype, conv in enumerate(self.convs[i]):
                # Keep only the edges of this relation; node ids stay
                # untouched so the feature tensors still line up.
                subg = graph.edge_subgraph(
                    graph.edata["etype"] == etype, relabel_nodes=False
                )
                # Flatten the head dimension back into hidden_channels.
                out += conv(subg, (x, x_dst)).view(-1, self.hidden_channels)

            x = self.norms[i](out)
            x = F.elu(x)
            x = self.dropout(x)
        return self.mlp(x)
class ExternalNodeCollator(dgl.dataloading.NodeCollator):
    """Node collator that fills in features and labels kept outside the graph.

    Parameters
    ----------
    g, idx, sampler
        Forwarded unchanged to ``dgl.dataloading.NodeCollator``.
    offset
        Value subtracted from the output node ids before indexing ``label``.
    feats
        Indexable feature store, addressed with the input node ids.
    label
        Indexable label store, addressed with ``output_nodes - offset``.
    """

    def __init__(self, g, idx, sampler, offset, feats, label):
        super().__init__(g, idx, sampler)
        self.feats = feats
        self.label = label
        self.offset = offset

    def collate(self, items):
        """Sample a batch and attach ``"x"`` / ``"y"`` to its outer blocks.

        Returns the same ``(input_nodes, output_nodes, mfgs)`` triple as the
        parent collator; the first block gains ``srcdata["x"]`` and the last
        block gains ``dstdata["y"]``.
        """
        input_nodes, output_nodes, mfgs = super().collate(items)
        first_block, last_block = mfgs[0], mfgs[-1]

        # Copy input features
        first_block.srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])

        # Labels are stored per paper, so shift node ids back by the offset.
        label_rows = output_nodes - self.offset
        last_block.dstdata["y"] = torch.LongTensor(self.label[label_rows])

        return input_nodes, output_nodes, mfgs
def train(args, dataset, g, feats, paper_offset):
    """Train an RGAT model and keep the checkpoint with the best validation accuracy.

    Parameters
    ----------
    args
        Parsed command-line arguments; ``args.epochs`` and
        ``args.model_path`` are read here.
    dataset
        Dataset object providing ``get_idx_split``, ``paper_label``,
        ``num_paper_features`` and ``num_classes``.
    g
        Graph handed to the node collators for neighbour sampling.
    feats
        Indexable node feature store passed to the collators.
    paper_offset
        Offset added to the split indices to obtain node ids in ``g``.

    The model is moved to CUDA, optimised with Adam (lr 0.001) and a
    ``StepLR`` schedule (step 25, gamma 0.25).  After every epoch the
    validation accuracy is printed and, if it improved, the state dict is
    written to ``args.model_path``.
    """
    print("Loading masks and labels")
    train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    label = dataset.paper_label

    print("Initializing dataloader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
    train_collator = ExternalNodeCollator(
        g, train_idx, sampler, paper_offset, feats, label
    )
    valid_collator = ExternalNodeCollator(
        g, valid_idx, sampler, paper_offset, feats, label
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_collator.dataset,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        collate_fn=train_collator.collate,
        num_workers=4,
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2,
    )

    print("Initializing model...")
    model = RGAT(
        in_channels=dataset.num_paper_features,
        out_channels=dataset.num_classes,
        hidden_channels=1024,
        num_etypes=5,
        num_layers=2,
        num_heads=4,
        dropout=0.5,
        pred_ntype="paper",
    ).cuda()
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)

    best_acc = 0
    for _ in range(args.epochs):
        # ---- optimisation pass over the training split ----
        model.train()
        with tqdm.tqdm(train_dataloader) as progress:
            for _, _, blocks in progress:
                blocks = [block.to("cuda") for block in blocks]
                x = blocks[0].srcdata["x"]
                y = blocks[-1].dstdata["y"]

                y_hat = model(blocks, x)
                loss = F.cross_entropy(y_hat, y)
                opt.zero_grad()
                loss.backward()
                opt.step()

                batch_acc = (y_hat.argmax(1) == y).float().mean()
                progress.set_postfix(
                    {
                        "loss": "%.4f" % loss.item(),
                        "acc": "%.4f" % batch_acc.item(),
                    },
                    refresh=False,
                )

        # ---- accuracy on the validation split ----
        model.eval()
        correct = 0
        total = 0
        for _, _, blocks in tqdm.tqdm(valid_dataloader):
            with torch.no_grad():
                blocks = [block.to("cuda") for block in blocks]
                x = blocks[0].srcdata["x"]
                y = blocks[-1].dstdata["y"]
                y_hat = model(blocks, x)
                correct += (y_hat.argmax(1) == y).sum().item()
                total += y_hat.shape[0]
        valid_acc = correct / total
        print("Validation accuracy:", valid_acc)

        sched.step()

        # Only a strictly better validation score replaces the checkpoint.
        if valid_acc > best_acc:
            best_acc = valid_acc
            print("Updating best model...")
            torch.save(model.state_dict(), args.model_path)
def test(args, dataset, g, feats, paper_offset):
    """Evaluate the saved model on the validation split and write test predictions.

    Parameters
    ----------
    args
        Parsed command-line arguments; ``args.model_path`` and
        ``args.submission_path`` are read here.
    dataset
        Dataset object providing ``get_idx_split``, ``paper_label``,
        ``num_paper_features`` and ``num_classes``.
    g
        Graph handed to the node collators for neighbour sampling.
    feats
        Indexable node feature store passed to the collators.
    paper_offset
        Offset added to the split indices to obtain node ids in ``g``.

    The model weights are loaded from ``args.model_path`` onto CUDA.  The
    validation accuracy is printed, then the arg-max class of every test
    node is collected (in loader order, no shuffling) and handed to
    ``MAG240MEvaluator.save_test_submission``.

    NOTE(review): the ``"test"`` split name and the two-argument
    ``save_test_submission`` call depend on the installed ``ogb`` release;
    confirm both against the version in use.
    """
    print("Loading masks and labels...")
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
    label = dataset.paper_label

    print("Initializing data loader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
    valid_collator = ExternalNodeCollator(
        g, valid_idx, sampler, paper_offset, feats, label
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2,
    )
    test_collator = ExternalNodeCollator(
        g, test_idx, sampler, paper_offset, feats, label
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=test_collator.collate,
        num_workers=4,
    )

    print("Loading model...")
    model = RGAT(
        dataset.num_paper_features,
        dataset.num_classes,
        1024,
        5,
        2,
        4,
        0.5,
        "paper",
    ).cuda()
    model.load_state_dict(torch.load(args.model_path))
    model.eval()

    # Validation accuracy of the restored checkpoint.
    correct = total = 0
    with torch.no_grad():
        for _, _, mfgs in tqdm.tqdm(valid_dataloader):
            mfgs = [mfg.to("cuda") for mfg in mfgs]
            x = mfgs[0].srcdata["x"]
            y = mfgs[-1].dstdata["y"]
            y_hat = model(mfgs, x)
            correct += (y_hat.argmax(1) == y).sum().item()
            total += y_hat.shape[0]
    acc = correct / total
    print("Validation accuracy:", acc)

    # Test predictions: only the predicted class is needed, so the label
    # tensor attached by the collator is not read here.
    evaluator = MAG240MEvaluator()
    y_preds = []
    with torch.no_grad():
        for _, _, mfgs in tqdm.tqdm(test_dataloader):
            mfgs = [mfg.to("cuda") for mfg in mfgs]
            x = mfgs[0].srcdata["x"]
            y_hat = model(mfgs, x)
            y_preds.append(y_hat.argmax(1).cpu())
    evaluator.save_test_submission(
        {"y_pred": torch.cat(y_preds)}, args.submission_path
    )
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help), listed in the
    # order they should appear in ``--help``.
    cli_options = (
        ("--rootdir", str, ".", "Directory to download the OGB dataset."),
        ("--graph-path", str, "./graph.dgl", "Path to the graph."),
        (
            "--full-feature-path",
            str,
            "./full.npy",
            "Path to the features of all nodes.",
        ),
        ("--epochs", int, 100, "Number of epochs."),
        ("--model-path", str, "./model.pt", "Path to store the best model."),
        ("--submission-path", str, "./results", "Submission directory."),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(
            flag, type=arg_type, default=default, help=help_text
        )
    args = parser.parse_args()

    dataset = MAG240MDataset(root=args.rootdir)

    print("Loading graph")
    graphs, _ = dgl.load_graphs(args.graph_path)
    # The file is expected to hold exactly one graph; unpacking fails otherwise.
    (g,) = graphs
    g = g.formats(["csc"])

    print("Loading features")
    # Paper indices are shifted by the combined author + institution count.
    paper_offset = dataset.num_authors + dataset.num_institutions
    num_nodes = paper_offset + dataset.num_papers
    num_features = dataset.num_paper_features
    # Read-only float16 memory map with one row per node.
    feats = np.memmap(
        args.full_feature_path,
        mode="r",
        dtype="float16",
        shape=(num_nodes, num_features),
    )

    # ``--epochs 0`` skips training and only evaluates an existing checkpoint.
    if args.epochs != 0:
        train(args, dataset, g, feats, paper_offset)
    test(args, dataset, g, feats, paper_offset)