GCN_embedding.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np


class GraphConv(nn.Module):
    '''Single graph convolution layer: y = adj @ x @ W (optionally + x and + bias),
    with optional L2 normalization of the node embeddings.'''
    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True):
        super(GraphConv, self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Parameters are allocated directly on the GPU; this code assumes CUDA is available.
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(output_dim).cuda())
        else:
            self.bias = None

    def forward(self, x, adj):
        # x: [batch_size, num_nodes, input_dim], adj: [batch_size, num_nodes, num_nodes]
        if self.dropout > 0.001:
            x = self.dropout_layer(x)
        # Aggregate neighbor features, then project with the layer weight.
        y = torch.matmul(adj, x)
        if self.add_self:
            y += x
        y = torch.matmul(y, self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            # L2-normalize each node embedding along the feature dimension.
            y = F.normalize(y, p=2, dim=2)
        return y
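
# A minimal shape sketch for GraphConv, kept as comments so the module still imports
# cleanly. The sizes below (batch_size=2, num_nodes=5, input_dim=3, output_dim=4) are
# illustrative assumptions, not values used elsewhere in this file; a CUDA device is
# assumed because the layer allocates its parameters with .cuda().
#
#   layer = GraphConv(input_dim=3, output_dim=4, normalize_embedding=True)
#   x = torch.rand(2, 5, 3).cuda()      # node features
#   adj = torch.rand(2, 5, 5).cuda()    # dense adjacency matrices
#   out = layer(x, adj)                 # -> shape [2, 5, 4]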


class GcnEncoderGraph(nn.Module):
    '''GCN encoder that stacks GraphConv layers and produces per-node embeddings.'''
    def __init__(self, input_dim, hidden_dim, embedding_dim, label_dim, num_layers,
                 pred_hidden_dims=[], concat=False, bn=True, dropout=0.0, args=None):
        super(GcnEncoderGraph, self).__init__()
        self.concat = concat
        add_self = not concat
        self.bn = bn
        self.num_layers = num_layers
        self.bias = True
        if args is not None:
            self.bias = args.bias

        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            input_dim, hidden_dim, embedding_dim, num_layers,
            add_self, normalize=True, dropout=dropout)
        self.act = nn.ReLU()
        self.label_dim = label_dim

        if concat:
            # Concatenation keeps the output of every layer, so the embedding width grows.
            self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim
        else:
            self.pred_input_dim = embedding_dim

        # Kaiming initialization for all GraphConv weights; biases start at zero.
        for m in self.modules():
            if isinstance(m, GraphConv):
                m.weight.data = init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0):
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias)
             for i in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias)
        return conv_first, conv_block, conv_last

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x.
        Note: a fresh BatchNorm1d module is created on every call, so its running
        statistics and affine parameters are not learned across batches.
        '''
        bn_module = nn.BatchNorm1d(x.size()[1]).cuda()
        return bn_module(x)

    def gcn_forward(self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None):
        ''' Perform forward prop with graph convolution.
        Returns:
            Embedding matrix with dimension [batch_size x num_nodes x embedding]
        '''
        x = conv_first(x, adj)
        x = self.act(x)  # ReLU
        if self.bn:
            x = self.apply_bn(x)
        x_all = [x]
        for i in range(len(conv_block)):
            x = conv_block[i](x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
            x_all.append(x)
        x = conv_last(x, adj)
        x_all.append(x)
        # Concatenate the per-layer node embeddings along the feature dimension.
        x_tensor = torch.cat(x_all, dim=2)
        if embedding_mask is not None:
            x_tensor = x_tensor * embedding_mask
        return x_tensor

    def forward(self, x, adj, batch_num_nodes=None, **kwargs):
        # Graph convolutions: first layer, (num_layers - 2) hidden layers, last layer.
        x = self.conv_first(x, adj)
        x = self.act(x)
        if self.bn:
            x = self.apply_bn(x)
        for i in range(self.num_layers - 2):
            x = self.conv_block[i](x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
        x = self.conv_last(x, adj)
        return x
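

# A minimal smoke-test sketch, assuming a CUDA device is available (the layers above
# allocate their parameters with .cuda()). All sizes below are illustrative assumptions,
# not values required by the model.
if __name__ == "__main__":
    batch_size, num_nodes, input_dim = 4, 10, 16
    model = GcnEncoderGraph(input_dim=input_dim, hidden_dim=32, embedding_dim=8,
                            label_dim=2, num_layers=3)
    x = torch.rand(batch_size, num_nodes, input_dim).cuda()       # node features
    adj = torch.rand(batch_size, num_nodes, num_nodes).cuda()     # dense adjacency
    emb = model(x, adj)
    print(emb.shape)  # expected: torch.Size([4, 10, 8])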