forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
353 lines (322 loc) · 9.21 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import torch
import torch.nn as nn
from dgl.nn.pytorch import (
AGNNConv,
APPNPConv,
ChebConv,
GATConv,
GINConv,
GraphConv,
SAGEConv,
SGConv,
TAGConv,
)
class GCN(nn.Module):
    """Graph convolutional network built from stacked ``GraphConv`` layers.

    The stack holds ``max(n_layers, 1) + 1`` convolutions: an input layer
    ``in_feats -> n_hidden``, ``n_layers - 1`` hidden layers
    ``n_hidden -> n_hidden`` (both using ``activation``), and an output
    layer ``n_hidden -> n_classes`` constructed without an activation.

    Args:
        g: graph every layer is applied to; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final layer.
        n_hidden: hidden feature size.
        n_layers: number of activated layers before the output layer.
        activation: activation given to every layer except the output one.
        dropout: dropout probability applied to the input of every layer
            except the first.
    """

    def __init__(
        self, g, in_feats, n_classes, n_hidden, n_layers, activation, dropout
    ):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # Input widths of the activated layers: the input layer followed by
        # n_layers - 1 hidden layers (an empty tail when n_layers <= 1).
        activated_inputs = [in_feats] + [n_hidden] * (n_layers - 1)
        for width in activated_inputs:
            self.layers.append(
                GraphConv(width, n_hidden, activation=activation)
            )
        # Output layer: no activation argument is passed.
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        """Run every layer over ``self.g`` and return the last layer's output."""
        first, *rest = self.layers
        h = first(self.g, features)
        for layer in rest:
            # Dropout is skipped only in front of the first layer.
            h = layer(self.g, self.dropout(h))
        return h
class GAT(nn.Module):
    """Graph attention network built from stacked multi-head ``GATConv`` layers.

    ``forward`` runs ``num_layers`` attention layers whose per-head outputs
    are flattened together, then an output ``GATConv`` whose outputs are
    averaged over dimension 1.

    Args:
        g: graph every layer attends over; kept on the module.
        in_dim: input node-feature size.
        num_classes: per-head output size of the output layer.
        num_hidden: per-head output size of every non-output layer.
        num_layers: number of attention layers run before the output layer.
        heads: head counts indexed by layer (``heads[l]`` for layer ``l``,
            ``heads[-1]`` for the output layer). NOTE(review): the output
            layer's input width uses ``heads[-2]``, which matches the last
            hidden layer only when ``len(heads) == num_layers + 1`` -- confirm
            callers always pass exactly that many entries.
        activation: activation for every non-output layer.
        feat_drop: forwarded unchanged to every ``GATConv``.
        attn_drop: forwarded unchanged to every ``GATConv``.
        negative_slope: forwarded unchanged to every ``GATConv``.
        residual: residual flag for every layer except the input layer,
            which is always built without one.
    """

    def __init__(
        self,
        g,
        in_dim,
        num_classes,
        num_hidden,
        num_layers,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation

        def attention_layer(fan_in, fan_out, n_heads, use_residual, act):
            # Dropout rates and LeakyReLU slope are shared by every layer;
            # only widths, head count, residual flag and activation vary.
            return GATConv(
                fan_in,
                fan_out,
                n_heads,
                feat_drop,
                attn_drop,
                negative_slope,
                use_residual,
                act,
            )

        # Input projection: never residual.
        self.gat_layers.append(
            attention_layer(
                in_dim, num_hidden, heads[0], False, self.activation
            )
        )
        # Hidden layers: the previous layer's heads are flattened together,
        # so the input width is num_hidden * (previous head count).
        for depth in range(1, num_layers):
            self.gat_layers.append(
                attention_layer(
                    num_hidden * heads[depth - 1],
                    num_hidden,
                    heads[depth],
                    residual,
                    self.activation,
                )
            )
        # Output projection: no activation; its heads are averaged in forward.
        self.gat_layers.append(
            attention_layer(
                num_hidden * heads[-2], num_classes, heads[-1], residual, None
            )
        )

    def forward(self, inputs):
        """Return the output layer's result averaged over dimension 1."""
        h = inputs
        for depth in range(self.num_layers):
            # flatten(1) merges the head axis into the features, assuming
            # GATConv returns (N, num_heads, out_feats) as DGL documents.
            h = self.gat_layers[depth](self.g, h).flatten(1)
        output_layer = self.gat_layers[-1]
        return output_layer(self.g, h).mean(1)
class GraphSAGE(nn.Module):
    """GraphSAGE network built from stacked ``SAGEConv`` layers.

    Layer widths: ``in_feats -> n_hidden``, then ``n_layers - 1`` times
    ``n_hidden -> n_hidden`` (both with ``activation``), then
    ``n_hidden -> n_classes`` with ``activation=None``. Every layer uses the
    same ``aggregator_type`` and receives ``dropout`` as its ``feat_drop``.

    Args:
        g: graph every layer is applied to; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final layer.
        n_hidden: hidden feature size.
        n_layers: number of activated layers before the output layer.
        activation: activation for every non-output layer.
        dropout: passed as ``feat_drop`` to every ``SAGEConv``.
        aggregator_type: neighbour aggregator passed to every ``SAGEConv``.
    """

    def __init__(
        self,
        g,
        in_feats,
        n_classes,
        n_hidden,
        n_layers,
        activation,
        dropout,
        aggregator_type,
    ):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # (input width, output width, activation) for each layer, in order.
        layer_specs = [(in_feats, n_hidden, activation)]
        layer_specs += [(n_hidden, n_hidden, activation)] * (n_layers - 1)
        layer_specs.append((n_hidden, n_classes, None))
        for fan_in, fan_out, act in layer_specs:
            self.layers.append(
                SAGEConv(
                    fan_in,
                    fan_out,
                    aggregator_type,
                    feat_drop=dropout,
                    activation=act,
                )
            )

    def forward(self, features):
        """Apply every ``SAGEConv`` in order over ``self.g``."""
        h = features
        for sage_layer in self.layers:
            h = sage_layer(self.g, h)
        return h
class APPNP(nn.Module):
    """MLP predictor followed by ``APPNPConv`` propagation over ``g``.

    An MLP of ``nn.Linear`` layers (``in_feats -> n_hidden``, ``n_layers - 1``
    times ``n_hidden -> n_hidden``, then ``n_hidden -> n_classes``) produces
    per-node outputs, which ``APPNPConv(k, alpha, edge_drop)`` then
    propagates over ``g``.

    Args:
        g: graph used for propagation; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final linear layer.
        n_hidden: hidden width of the MLP.
        n_layers: number of activated linear layers before the output layer.
        activation: callable applied after every non-final linear layer.
        feat_drop: dropout probability applied to the MLP input and to the
            input of the final linear layer; a falsy value disables it.
        edge_drop: forwarded to ``APPNPConv``.
        alpha: forwarded to ``APPNPConv``.
        k: forwarded to ``APPNPConv``.
    """

    def __init__(
        self,
        g,
        in_feats,
        n_classes,
        n_hidden,
        n_layers,
        activation,
        feat_drop,
        edge_drop,
        alpha,
        k,
    ):
        super(APPNP, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(nn.Linear(in_feats, n_hidden))
        # hidden layers
        for _ in range(1, n_layers):
            self.layers.append(nn.Linear(n_hidden, n_hidden))
        # output layer
        self.layers.append(nn.Linear(n_hidden, n_classes))
        self.activation = activation
        # nn.Identity rather than `lambda x: x`: a lambda attribute makes the
        # model unpicklable (e.g. torch.save(model)) and is hidden from repr().
        self.feat_drop = nn.Dropout(feat_drop) if feat_drop else nn.Identity()
        self.propagate = APPNPConv(k, alpha, edge_drop)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer of the MLP."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, features):
        """Run the MLP prediction step, then propagate the result over ``self.g``."""
        # prediction step
        h = self.feat_drop(features)
        h = self.activation(self.layers[0](h))
        for layer in self.layers[1:-1]:
            h = self.activation(layer(h))
        # Final linear layer: dropout on its input, no activation.
        h = self.layers[-1](self.feat_drop(h))
        # propagation step
        return self.propagate(self.g, h)
class TAGCN(nn.Module):
    """Topology-adaptive GCN built from stacked ``TAGConv`` layers.

    Layer widths: ``in_feats -> n_hidden``, then ``n_layers - 1`` times
    ``n_hidden -> n_hidden`` (both with ``activation``), then
    ``n_hidden -> n_classes`` built without an activation argument.

    Args:
        g: graph every layer is applied to; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final layer.
        n_hidden: hidden feature size.
        n_layers: number of activated layers before the output layer.
        activation: activation for every non-output layer.
        dropout: dropout probability applied to the input of every layer
            except the first.
    """

    def __init__(
        self, g, in_feats, n_classes, n_hidden, n_layers, activation, dropout
    ):
        super(TAGCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList(
            [TAGConv(in_feats, n_hidden, activation=activation)]
        )
        self.layers.extend(
            TAGConv(n_hidden, n_hidden, activation=activation)
            for _ in range(n_layers - 1)
        )
        # Output layer: no activation argument is passed.
        self.layers.append(TAGConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        """Run every layer over ``self.g``, with dropout between layers."""
        h = self.layers[0](self.g, features)
        for tag_layer in self.layers[1:]:
            h = tag_layer(self.g, self.dropout(h))
        return h
class AGNN(nn.Module):
    """Attention-based GNN: linear projection, ``AGNNConv`` stack, linear classifier.

    Pipeline: dropout -> ``Linear(in_feats, n_hidden)`` -> ReLU, then
    ``n_layers`` ``AGNNConv`` layers over ``g``, then dropout ->
    ``Linear(n_hidden, n_classes)``.

    Args:
        g: graph the ``AGNNConv`` layers propagate over; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the classifier.
        n_hidden: width of the projected features.
        n_layers: number of ``AGNNConv`` layers.
        init_beta: forwarded to every ``AGNNConv``.
        learn_beta: forwarded to every ``AGNNConv``.
        dropout: dropout probability in front of each of the two linear layers.
    """

    def __init__(
        self,
        g,
        in_feats,
        n_classes,
        n_hidden,
        n_layers,
        init_beta,
        learn_beta,
        dropout,
    ):
        super(AGNN, self).__init__()
        self.g = g
        propagation = [AGNNConv(init_beta, learn_beta) for _ in range(n_layers)]
        self.layers = nn.ModuleList(propagation)
        projection = (
            nn.Dropout(dropout),
            nn.Linear(in_feats, n_hidden),
            nn.ReLU(),
        )
        self.proj = nn.Sequential(*projection)
        classifier = (nn.Dropout(dropout), nn.Linear(n_hidden, n_classes))
        self.cls = nn.Sequential(*classifier)

    def forward(self, features):
        """Project, propagate through every ``AGNNConv``, then classify."""
        hidden = self.proj(features)
        for agnn_layer in self.layers:
            hidden = agnn_layer(self.g, hidden)
        return self.cls(hidden)
class SGC(nn.Module):
    """Simplified graph convolution: a single ``SGConv`` from features to classes.

    Args:
        g: graph to propagate over; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size.
        n_hidden: unused here; presumably accepted so the constructor lines
            up with the other models' argument lists.
        k: forwarded to ``SGConv`` as ``k``.
        bias: forwarded to ``SGConv`` as ``bias``.
    """

    def __init__(self, g, in_feats, n_classes, n_hidden, k, bias):
        super(SGC, self).__init__()
        self.g = g
        # cached=True: per DGL's SGConv docs the propagated features are
        # computed once and reused, so this assumes the same graph and input
        # features on every call (transductive use) -- confirm with callers.
        self.net = SGConv(in_feats, n_classes, k=k, cached=True, bias=bias)

    def forward(self, features):
        """Apply the single ``SGConv`` over ``self.g``."""
        sgconv = self.net
        return sgconv(self.g, features)
class GIN(nn.Module):
    """Graph isomorphism network built from stacked ``GINConv`` layers.

    Each layer wraps a small MLP in ``GINConv``. The MLP is dropout -> linear,
    followed by a ReLU on every layer except the output one. Widths are
    ``in_feats -> n_hidden``, then ``n_layers - 1`` times
    ``n_hidden -> n_hidden``, then ``n_hidden -> n_classes``.

    Args:
        g: graph every layer aggregates over; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final layer.
        n_hidden: hidden feature size.
        n_layers: number of ReLU layers before the output layer.
        init_eps: forwarded to every ``GINConv``.
        learn_eps: forwarded to every ``GINConv``.
        dropout: dropout probability at the start of every layer's MLP.
            Defaults to 0.6, the rate previously hard-coded here.
        aggregator_type: neighbour aggregator for every ``GINConv``.
            Defaults to ``"mean"``, the value previously hard-coded here.
    """

    def __init__(
        self,
        g,
        in_feats,
        n_classes,
        n_hidden,
        n_layers,
        init_eps,
        learn_eps,
        dropout=0.6,
        aggregator_type="mean",
    ):
        super(GIN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()

        def gin_layer(fan_in, fan_out, is_output):
            # Build one GINConv around a dropout -> linear (-> ReLU) MLP.
            mlp = [nn.Dropout(dropout), nn.Linear(fan_in, fan_out)]
            if not is_output:
                mlp.append(nn.ReLU())
            return GINConv(
                nn.Sequential(*mlp), aggregator_type, init_eps, learn_eps
            )

        self.layers.append(gin_layer(in_feats, n_hidden, is_output=False))
        for _ in range(n_layers - 1):
            self.layers.append(gin_layer(n_hidden, n_hidden, is_output=False))
        self.layers.append(gin_layer(n_hidden, n_classes, is_output=True))

    def forward(self, features):
        """Apply every ``GINConv`` in order over ``self.g``."""
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h
class ChebNet(nn.Module):
    """Chebyshev spectral GNN built from stacked ``ChebConv`` layers.

    Layer widths: ``in_feats -> n_hidden``, then ``n_layers - 1`` times
    ``n_hidden -> n_hidden``, then ``n_hidden -> n_classes``. All layers
    share the order ``k`` and the ``bias`` flag. ``forward`` adds no dropout
    or extra activation between layers. NOTE(review): every ``ChebConv``
    (including the output one) uses its library-default activation --
    confirm that is intended.

    Args:
        g: graph every layer is applied to; kept on the module.
        in_feats: input node-feature size.
        n_classes: output size of the final layer.
        n_hidden: hidden feature size.
        n_layers: number of layers before the output layer.
        k: forwarded to every ``ChebConv``.
        bias: forwarded to every ``ChebConv``.
    """

    def __init__(self, g, in_feats, n_classes, n_hidden, n_layers, k, bias):
        super(ChebNet, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # Input widths of the layers mapping to n_hidden; the hidden tail is
        # empty when n_layers <= 1.
        for fan_in in [in_feats] + [n_hidden] * (n_layers - 1):
            self.layers.append(ChebConv(fan_in, n_hidden, k, bias=bias))
        self.layers.append(ChebConv(n_hidden, n_classes, k, bias=bias))

    def forward(self, features):
        """Apply every ``ChebConv`` in order over ``self.g``."""
        h = features
        for cheb_layer in self.layers:
            # The third argument is ChebConv's lambda_max. Per DGL's docs it is
            # the largest eigenvalue of the normalized Laplacian, one entry per
            # graph in the batch. 2 bounds that spectrum and is used instead of
            # the exact value. A single entry presumably means self.g is one
            # non-batched graph -- confirm.
            h = cheb_layer(self.g, h, [2])
        return h