-
Notifications
You must be signed in to change notification settings - Fork 1
/
example_test_transe.py
executable file
·49 lines (48 loc) · 1.38 KB
/
example_test_transe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import config
import models
import tensorflow as tf
import numpy as np
import json
import os
# Expose only CUDA device index 7 to this process. The variable must be set
# before any CUDA context is created for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

# (1) Set import files and OpenKE will automatically load the saved model
#     (through TensorFlow's tf.train.Saver) when the model is set up.
con = config.Config()

# Dataset location and evaluation mode. Triple classification is enabled;
# link prediction is left disabled (uncomment the line below to enable it).
con.set_in_path("./benchmarks/FB15K/")
# con.set_test_link_prediction(True)
con.set_test_triple_classification(True)

# Runtime settings: worker thread count and embedding dimension.
# NOTE(review): the dimension presumably has to match the one used when the
# checkpoint was trained -- confirm against the training script.
con.set_work_threads(8)
con.set_dimension(100)

# Checkpoint to restore, then initialise, attach the TransE model and run
# the configured test.
con.set_import_files("./res/model.vec.tf")
con.init()
con.set_model(models.TransE)
con.test()

# Sample queries against the restored model. Return values are not captured
# here. NOTE(review): the integers are presumably entity/relation ids with a
# trailing top-k count for the first three calls -- verify against the Config
# API before reusing them.
con.predict_head_entity(152, 9, 5)
con.predict_tail_entity(151, 9, 5)
con.predict_relation(151, 152, 5)
for triple_args in ((151, 152, 9), (151, 152, 8)):
    con.predict_triple(*triple_args)

# Optional inspection helpers, disabled by default.
# con.show_link_prediction(2,1)
# con.show_triple_classification(2,1,3)
# (2) Read model parameters from json files and manually load parameters.
# con = config.Config()
# con.set_in_path("./benchmarks/FB15K/")
# con.set_test_flag(True)
# con.set_work_threads(4)
# con.set_dimension(50)
# con.init()
# con.set_model(models.TransE)
# f = open("./res/embedding.vec.json", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
# con.test()
# (3) Manually load models via tf.train.Saver().
# con = config.Config()
# con.set_in_path("./benchmarks/FB15K/")
# con.set_test_flag(True)
# con.set_work_threads(4)
# con.set_dimension(50)
# con.init()
# con.set_model(models.TransE)
# con.import_variables("./res/model.vec.tf")
# con.test()