# cluster_gcn.py
#
# Trains a three-layer GraphSAGE model on the ogbn-products dataset using
# DGL's ClusterGCNSampler, printing per-epoch training time and
# validation/test accuracy.
import time
import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from ogb.nodeproppred import DglNodePropPredDataset
class SAGE(nn.Module):
    """Three-layer GraphSAGE network built from mean-aggregating SAGEConv layers.

    Args:
        in_feats: Input feature size passed to the first ``SAGEConv``.
        n_hidden: Output size of the first two layers.
        n_classes: Output size of the last layer.
        dropout: Dropout probability applied after every non-final layer.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, in_feats, n_hidden, n_classes, dropout=0.5):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_feats, n_hidden, "mean"),
                dglnn.SAGEConv(n_hidden, n_hidden, "mean"),
                dglnn.SAGEConv(n_hidden, n_classes, "mean"),
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, sg, x):
        """Run every layer on the same graph ``sg``, starting from features ``x``.

        ReLU and dropout are applied after each layer except the last, so the
        returned tensor is the raw output of the final layer.
        """
        h = x
        last_idx = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            h = layer(sg, h)
            if idx != last_idx:
                h = F.relu(h)
                h = self.dropout(h)
        return h
# AsNodePredDataset already prepares
# ndata['label'/'train_mask'/'val_mask'/'test_mask'] on the graph.
dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
graph = dataset[0]

in_feats = graph.ndata["feat"].shape[1]
model = SAGE(in_feats, 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

num_partitions = 1000
# Node data fields the sampler is asked to prefetch for each sampled subgraph.
prefetch_fields = ["feat", "label", "train_mask", "val_mask", "test_mask"]
sampler = dgl.dataloading.ClusterGCNSampler(
    graph,
    num_partitions,
    prefetch_ndata=prefetch_fields,
)

# The generic DataLoader takes a graph, a set of indices (any indices; here
# they are partition IDs) and a graph sampler.
partition_ids = torch.arange(num_partitions).to("cuda")
dataloader = dgl.dataloading.DataLoader(
    graph,
    partition_ids,
    sampler,
    device="cuda",
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)
# Wall-clock seconds spent in the training pass of each epoch.
durations = []
for epoch in range(10):
    # ---- training pass ----
    t0 = time.time()
    model.train()
    for it, sg in enumerate(dataloader):
        feats = sg.ndata["feat"]
        labels = sg.ndata["label"]
        train_mask = sg.ndata["train_mask"].bool()
        logits = model(sg, feats)
        # Only the nodes selected by train_mask contribute to the loss.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 != 0:
            continue
        # Every 20th iteration: report loss, training accuracy and peak GPU memory.
        acc = MF.accuracy(
            logits[train_mask],
            labels[train_mask],
            task="multiclass",
            num_classes=dataset.num_classes,
        )
        mem = torch.cuda.max_memory_allocated() / 1000000
        print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
    elapsed = time.time() - t0
    print("Run time for epoch# %d: %.2fs" % (epoch, elapsed))
    durations.append(elapsed)

    # ---- evaluation pass (not included in the timing above) ----
    model.eval()
    with torch.no_grad():
        val_preds, val_labels = [], []
        test_preds, test_labels = [], []
        for sg in dataloader:
            feats = sg.ndata["feat"]
            labels = sg.ndata["label"]
            val_mask = sg.ndata["val_mask"].bool()
            test_mask = sg.ndata["test_mask"].bool()
            logits = model(sg, feats)
            val_preds.append(logits[val_mask])
            val_labels.append(labels[val_mask])
            test_preds.append(logits[test_mask])
            test_labels.append(labels[test_mask])
        # Accuracy is computed once over the predictions gathered from all batches.
        metric_kwargs = {"task": "multiclass", "num_classes": dataset.num_classes}
        val_acc = MF.accuracy(
            torch.cat(val_preds, 0), torch.cat(val_labels, 0), **metric_kwargs
        )
        test_acc = MF.accuracy(
            torch.cat(test_preds, 0), torch.cat(test_labels, 0), **metric_kwargs
        )
        print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())
# Report mean/std of the per-epoch run times, skipping the first four epochs.
# The epoch count comes from the slice itself rather than from the loop
# variable, so the printed number always matches the data being averaged.
timed_durations = durations[4:]
if timed_durations:
    # Skipped when there is nothing to average: np.mean([]) would yield nan
    # and emit a RuntimeWarning.
    print(
        "Average run time for last %d epochs: %.2fs standard deviation: %.3f"
        % (len(timed_durations), np.mean(timed_durations), np.std(timed_durations))
    )