-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
181 lines (157 loc) · 5.92 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import argparse
import os
import pickle
import dgl
import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
import torchtext
import tqdm
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
class PinSAGEModel(nn.Module):
    """Item-embedding model scored with a margin-1 hinge loss.

    The model is composed of three sub-modules built from the project's
    ``layers`` module:

    * ``proj``: applied to node data (``srcdata`` / ``dstdata``) of the
      message-flow blocks to produce initial item vectors.
    * ``sage``: given the blocks and the projected input vectors, produces a
      tensor that is added to the projected destination vectors.
    * ``scorer``: scores item pairs described by a pair graph.

    The attribute names ``proj``, ``sage`` and ``scorer`` are used as
    ``state_dict`` keys, so they must stay stable for saved checkpoints.
    """

    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()
        self.proj = layers.LinearProjector(
            full_graph, ntype, textsets, hidden_dims
        )
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        """Return the element-wise hinge loss ``max(0, neg - pos + 1)``.

        Args:
            pos_graph: pair graph of positive item pairs, passed to ``scorer``.
            neg_graph: pair graph of negative item pairs, passed to ``scorer``.
            blocks: message-flow blocks forwarded to ``get_repr``.

        Returns:
            Per-pair losses; callers in this file reduce them with ``.mean()``.
        """
        item_embeddings = self.get_repr(blocks)
        positive = self.scorer(pos_graph, item_embeddings)
        negative = self.scorer(neg_graph, item_embeddings)
        # A positive pair should outscore its negative by at least 1; only
        # violations of that margin contribute to the loss.
        margin_violation = negative - positive + 1
        return margin_violation.clamp(min=0)

    def get_repr(self, blocks):
        """Compute representations for the output nodes of ``blocks``.

        The result is the projection of the last block's destination data
        plus the ``sage`` output computed from the first block's projected
        source data.
        """
        input_feats = self.proj(blocks[0].srcdata)
        output_feats = self.proj(blocks[-1].dstdata)
        aggregated = self.sage(blocks, input_feats)
        return output_feats + aggregated
def _build_textset(g, item_ntype, item_texts):
    """Tokenize each item text field and build one vocabulary per field.

    Args:
        g: graph whose ``item_ntype`` node count gives the number of items.
        item_ntype: name of the item node type in ``g``.
        item_texts: mapping from field name to an indexable sequence of
            strings; ``item_texts[key][i]`` is item ``i``'s text for ``key``.

    Returns:
        dict mapping each field name to the tuple
        ``(token_lists, vocab, pad_index, batch_first)``, where
        ``token_lists[i]`` holds the lower-cased tokens of item ``i`` for
        that field only and ``pad_index`` is the vocabulary index of
        ``"<pad>"``.
    """
    # Per the torchtext docs, a ``None`` tokenizer splits on whitespace.
    tokenizer = get_tokenizer(None)
    num_items = g.num_nodes(item_ntype)
    batch_first = True
    textset = {}
    for key, texts in item_texts.items():
        # Build a separate list per field so position i always holds item i's
        # tokens for *this* field. A single list shared by all fields would
        # interleave fields and misalign item positions whenever there is
        # more than one field.
        token_lists = [tokenizer(texts[i].lower()) for i in range(num_items)]
        vocab = build_vocab_from_iterator(
            token_lists, specials=["<unk>", "<pad>"]
        )
        textset[key] = (
            token_lists,
            vocab,
            vocab.get_stoi()["<pad>"],
            batch_first,
        )
    return textset


def _evaluate(model, dataloader_test, dataset, device, args):
    """Embed every item with ``model`` and print ``evaluation.evaluate_nn``'s result."""
    model.eval()
    with torch.no_grad():
        h_item_batches = []
        # dataloader_test walks item IDs 0..N-1 in order (no shuffling), so
        # the concatenated rows line up with item IDs.
        for blocks in dataloader_test:
            blocks = [block.to(device) for block in blocks]
            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)
        print(
            evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
        )


def train(dataset, args):
    """Train a PinSAGE model and print an evaluation after every epoch.

    Args:
        dataset: dict as assembled by this file's ``__main__`` block. This
            function reads ``"train-graph"`` (the graph returned by
            ``dgl.load_graphs``), ``"item-texts"`` (field name -> per-item
            strings), ``"user-type"`` and ``"item-type"`` (node type names).
            The whole dict is also passed to ``evaluation.evaluate_nn``.
        args: parsed command-line namespace; uses ``device``, ``batch_size``,
            the random-walk / neighbor-sampling options, ``num_layers``,
            ``hidden_dims``, ``num_workers``, ``lr``, ``num_epochs``,
            ``batches_per_epoch`` and ``k``.

    Side effects:
        Adds an ``"id"`` node feature to the user and item node types of the
        training graph, and prints evaluation output once per epoch.
    """
    g = dataset["train-graph"]
    item_texts = dataset["item-texts"]
    user_ntype = dataset["user-type"]
    item_ntype = dataset["item-type"]
    device = torch.device(args.device)

    # Assign user and movie IDs and use them as features (to learn an
    # individual trainable embedding for each entity).
    g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
    g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))

    # Tokenized texts and vocabularies, one entry per text field.
    textset = _build_textset(g, item_ntype, item_texts)

    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size
    )
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        args.random_walk_length,
        args.random_walk_restart_prob,
        args.num_random_walks,
        args.num_neighbors,
        args.num_layers,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers,
    )
    # One iterator is shared across all epochs, so each epoch continues where
    # the previous one stopped. NOTE(review): this presumably relies on
    # ItemToItemBatchSampler yielding indefinitely; if it is finite,
    # next() will raise StopIteration. TODO confirm.
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(
        g, item_ntype, textset, args.hidden_dims, args.num_layers
    ).to(device)
    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # For each batch of head-tail-negative triplets...
    for _epoch in range(args.num_epochs):
        model.train()
        for _batch in tqdm.trange(args.batches_per_epoch):
            pos_graph, neg_graph, blocks = next(dataloader_it)
            # Copy to the target device
            blocks = [block.to(device) for block in blocks]
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)

            loss = model(pos_graph, neg_graph, blocks).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        _evaluate(model, dataloader_test, dataset, device, args)
if __name__ == "__main__":
    # Arguments: the positional dataset directory followed by
    # (flag, type, default) options, registered in this exact order so that
    # --help output is unchanged.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", type=str)
    option_specs = [
        ("--random-walk-length", int, 2),
        ("--random-walk-restart-prob", float, 0.5),
        ("--num-random-walks", int, 10),
        ("--num-neighbors", int, 3),
        ("--num-layers", int, 2),
        ("--hidden-dims", int, 16),
        ("--batch-size", int, 32),
        ("--device", str, "cpu"),  # can also be "cuda:0"
        ("--num-epochs", int, 1),
        ("--batches-per-epoch", int, 20000),
        ("--num-workers", int, 0),
        ("--lr", float, 3e-5),
        ("-k", int, 10),
    ]
    for flag, arg_type, default_value in option_specs:
        parser.add_argument(flag, type=arg_type, default=default_value)
    args = parser.parse_args()

    # Load dataset: the metadata dict from data.pkl, plus the training graph
    # from train_g.bin, stored under the "train-graph" key.
    dataset_dir = args.dataset_path
    with open(os.path.join(dataset_dir, "data.pkl"), "rb") as f:
        # NOTE(review): pickle.load can execute arbitrary code embedded in
        # the file; only point this script at trusted dataset directories.
        dataset = pickle.load(f)
    graphs, _ = dgl.load_graphs(os.path.join(dataset_dir, "train_g.bin"))
    dataset["train-graph"] = graphs[0]
    train(dataset, args)