# coding: utf-8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# K-core subgraph based diffusion layer
class CoreDiffusion(nn.Module):
    input_dim: int
    output_dim: int
    core_num: int
    bias: bool
    rnn_type: str

    def __init__(self, input_dim, output_dim, core_num=1, bias=True, rnn_type='GRU'):
        super(CoreDiffusion, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.core_num = core_num
        self.rnn_type = rnn_type
        self.linear = nn.Linear(input_dim, output_dim)
        # self.att_weight = nn.Parameter(torch.FloatTensor(core_num))
        assert self.rnn_type in ['LSTM', 'GRU']
        if self.rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size=input_dim, hidden_size=output_dim, num_layers=1, bias=bias, batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_dim, hidden_size=output_dim, num_layers=1, bias=bias, batch_first=True)
        self.norm = nn.LayerNorm(output_dim)
        # self.reset_parameters()

    # def reset_parameters(self):
    #     # stdv = 1. / math.sqrt(self.weight.size(1))
    #     self.att_weight.data.uniform_(0, 1)
    def forward(self, x, adj_list):
        hx_list = []
        # Diffuse node features over each k-core subgraph, accumulating the result of the previous core.
        for i, adj in enumerate(adj_list):
            if i == 0:
                res = torch.sparse.mm(adj, x)
            else:
                res = hx_list[-1] + torch.sparse.mm(adj, x)
            # hx = self.linear(res)
            hx_list.append(res)
        hx_list = [F.relu(res) for res in hx_list]
        #################################
        # Simple core diffusion, no RNN:
        # out = hx_list[0]
        # for i, res in enumerate(hx_list[1:]):
        #     out = out + res
        # output = self.linear(out)
        ##################################
        # Running an RNN over the core sequence improves performance, at a small cost in computational efficiency.
        hx = torch.stack(hx_list, dim=0).transpose(0, 1)  # [batch_size, core_num, input_dim]
        output, _ = self.rnn(hx)
        output = output.sum(dim=1)
        # Layer normalization can improve performance and stabilize the RNN.
        output = self.norm(output)
        return output
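

# Hypothetical usage sketch (not part of the original layer definitions): shows how
# CoreDiffusion can be driven with random node features and a list of sparse k-core
# adjacency matrices. The shapes, the identity adjacencies, and the helper name
# _core_diffusion_demo are illustrative assumptions only.
def _core_diffusion_demo(node_num=5, input_dim=8, output_dim=4, core_num=3):
    x = torch.randn(node_num, input_dim)
    # One sparse adjacency matrix per k-core subgraph; identity matrices stand in here.
    adj_list = [torch.eye(node_num).to_sparse() for _ in range(core_num)]
    layer = CoreDiffusion(input_dim, output_dim, core_num=core_num, rnn_type='GRU')
    return layer(x, adj_list)  # expected shape: [node_num, output_dim]
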
# Multi-Layer Perceptron (MLP) layer
class MLP(nn.Module):
    input_dim: int
    hidden_dim: int
    output_dim: int
    layer_num: int
    bias: bool
    activate_type: str

    def __init__(self, input_dim, hidden_dim, output_dim, layer_num, bias=True, activate_type='N'):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.layer_num = layer_num
        self.bias = bias
        self.activate_type = activate_type
        assert self.activate_type in ['L', 'N']
        assert self.layer_num > 0
        if layer_num == 1:
            self.linear = nn.Linear(input_dim, output_dim, bias=bias)
        else:
            self.linears = torch.nn.ModuleList()
            self.linears.append(nn.Linear(input_dim, hidden_dim, bias=bias))
            for layer in range(layer_num - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            self.linears.append(nn.Linear(hidden_dim, output_dim, bias=bias))

    def forward(self, x):
        if self.layer_num == 1:  # Linear model
            x = self.linear(x)
            if self.activate_type == 'N':
                x = F.selu(x)
            return x
        h = x  # MLP
        for layer in range(self.layer_num):
            h = self.linears[layer](h)
            if self.activate_type == 'N':
                h = F.selu(h)
        return h
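

# Hypothetical usage sketch (not part of the original layer definitions): feeds random
# features through a 3-layer MLP with SELU activations and runs both demos. All sizes
# and the helper name _mlp_demo are illustrative assumptions only.
def _mlp_demo(node_num=5, input_dim=8, hidden_dim=16, output_dim=4):
    x = torch.randn(node_num, input_dim)
    mlp = MLP(input_dim, hidden_dim, output_dim, layer_num=3, activate_type='N')
    return mlp(x)  # expected shape: [node_num, output_dim]


if __name__ == '__main__':
    print(_core_diffusion_demo().shape)  # torch.Size([5, 4])
    print(_mlp_demo().shape)             # torch.Size([5, 4])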