-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
117 lines (93 loc) · 3.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import importlib
import json
import os

import torch

from triplets import *
from metrics import *
#parser for all arguments!
parser = argparse.ArgumentParser(description='Testing knowledge graph embeddings...')
#requirement arguments...
parser.add_argument("save_path",
type=str, help="Directory where model is saved")
parser.add_argument("algorithm",
type=str, help="Embedding algorithm of saved model to be tested!")
parser.add_argument("metric",
choices=['mean_rank', 'hits@'],
type=str, help="Metric to be used for testing!")
#optional requirements!
parser.add_argument("--N",
default=10,
type=int, help="hits@N N. Only used when hits@ is used as argument in metric")
parser.add_argument("--test_data",
default='./FB15k_237/test.txt',
type=str, help="Path to test data")
parser.add_argument("--filtering",
default = False,
type=bool, help="Filter out true triplets, that artificially lower the scores...")
parser.add_argument("--train_data",
default='./FB15k_237/train.txt',
type=str, help="Path to training data (used for filtering)")
parser.add_argument("--val_data",
default='./FB15k_237/valid.txt',
type=str, help="Path to validation data (used for filtering)")
parser.add_argument("--big",
default='10e5',
type=float, help="Value of mask, so as to filter out golden triples")
parser.add_argument("--batch_size",
default=64,
type=int, help="Test batch size")
parser.add_argument("--seed",
default=42,
type=int, help="Seed for randomness")
#finds all arguments...
args = parser.parse_args()
#bind frequently used options to module-level names
SAVE_PATH = args.save_path
ALGORITHM = args.algorithm
metric_ = args.metric
filtering = args.filtering
test_data = args.test_data
train_data = args.train_data
val_data = args.val_data
N = args.N
batch_size = args.batch_size
#seeds — make ranking/batching reproducible across runs
torch.manual_seed(args.seed)
#import the chosen embedding algorithm from the local `algorithms` package.
#NOTE: the name is absolute, so no `package` argument is needed
#(the previous `"."` was silently ignored by import_module).
module = importlib.import_module('algorithms.' + args.algorithm)
#directory where triplets are stored... as well as ids!
id_dir = os.path.dirname(test_data)
#loading ids... explicit utf-8 so parsing doesn't depend on the locale's
#default encoding; os.path.join instead of string concatenation
with open(os.path.join(id_dir, 'entity2id.json'), 'r', encoding='utf-8') as f:
    unique_objects = json.load(f)
with open(os.path.join(id_dir, 'relationship2id.json'), 'r', encoding='utf-8') as f:
    unique_relationships = json.load(f)
#data — train/validation splits are only needed to build the filter of
#known-true triplets, so skip loading them unless filtering is on
if filtering:
    #training
    train = Triplets(path = train_data, unique_objects = unique_objects,
                     unique_relationships = unique_relationships)
    #validation
    val = Triplets(path = val_data, unique_objects = unique_objects,
                   unique_relationships = unique_relationships)
#test
test = Triplets(path = test_data, unique_objects = unique_objects,
                unique_relationships = unique_relationships)
#load model... construct with the entity/relation vocabulary sizes, then
#restore trained weights from SAVE_PATH
model = module.Model(len(unique_objects), len(unique_relationships))
model = model.load(SAVE_PATH)
#filter of known-true triplets (train+val+test) that would otherwise be
#counted as ranking errors; renamed from `filter` to avoid shadowing the builtin
if filtering:
    triplet_filter = Filter(train, val, test, big = args.big)
else:
    triplet_filter = None
#append results next to the saved model. os.path.join (not `path + "/..."`)
#so a bare filename SAVE_PATH yields "test.txt", not the root path "/test.txt"
path = os.path.dirname(SAVE_PATH)
if metric_ == 'mean_rank':
    result = mean_rank(test, model, filter = triplet_filter, batch_size = batch_size)
    with open(os.path.join(path, "test.txt"), "a") as myfile:
        s = 'Filt.' if filtering else 'Raw'
        myfile.write(f'{s} mean rank = {result:.2f}\n')
elif metric_ == 'hits@':
    result = hits_at_N(test, model, N=N, filter = triplet_filter, batch_size = batch_size)
    with open(os.path.join(path, "test.txt"), "a") as myfile:
        s = 'Filt.' if filtering else 'Raw'
        myfile.write(f'{s} hits@{N} = {result*100:.2f}%\n')
else:
    #unreachable given argparse `choices`, but fail with a real error
    #instead of a bare `raise` (which itself throws RuntimeError)
    raise ValueError(f'Unknown metric: {metric_!r}')