-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
73 lines (60 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow as tf
import src.thalamus as core
import experiments.map as world
import argparse
import src.util as util
import os
# Command-line interface for the demo. The arguments are parsed as soon as
# this module executes, so `args` is ready for the script body.
parser = argparse.ArgumentParser()

# When set, the script skips restoring previously saved weights.
parser.add_argument(
    "--reset",
    action="store_true",
    help="reset weight",
)

# NOTE(review): --rate is parsed but never read anywhere in this file, so it
# currently has no effect on training. Wire it into the model or drop it.
parser.add_argument(
    "--rate",
    type=float,
    help="learning rate",
)

# Number of random maps to train on. The script supplies its own fallback
# when this is omitted.
parser.add_argument(
    "--total",
    type=int,
    help="total maps",
)

# Switches the script from training to sample generation.
parser.add_argument(
    "--gen",
    action="store_true",
    help="generate mode, ignore total flag",
)

args = parser.parse_args()
if __name__ == '__main__':
world_size = (60, 60)
past_steps = 2
components = 2
component_size = 30
belief_depth = 20
total_maps = 5 if not args.total else args.total
print "-----------------------"
print "world size: ", world_size
print "past steps: ", past_steps
print "components: ", components
print "component size: ", component_size
print "memory depth: ", belief_depth
print "total maps: ", total_maps
print "-----------------------"
sess = tf.Session()
machine = core.Machine(sess, world_size[0] * world_size[1], past_steps, components, component_size, belief_depth)
sess.run(tf.global_variables_initializer())
if not args.reset:
machine.load_session("./artifacts/demo")
if not args.gen:
for m in xrange(total_maps):
frames = world.get_valid_data(world_size, world_size[0] / 10, map_complexity=6, length_modifier=0.2)
machine.reset_memory()
for i in xrange(0, frames.shape[0]):
pasts = util.prepare_data(frames, i - past_steps, i)
input_data = util.prepare_data(frames, i, i + 1)
print "-----------"
# learn and save model
machine.learn(input_data, pasts, 20, "./artifacts/demo")
else:
generated_frames = []
frames = world.get_valid_data(world_size, world_size[0] / 10, map_complexity=6, length_modifier=0.2)
machine.reset_memory()
pasts = util.prepare_data(frames, 0 - past_steps, 0)
input_data = util.prepare_data(frames, 0, 0 + 1)
# Learn the target frame
machine.learn(input_data, pasts, 100)
for i in xrange(1, frames.shape[0]):
pasts = util.prepare_data(frames, i - past_steps, i)
input_data = util.prepare_data(frames, i, i + 1)
# generate thoughts
gen = machine.generate_thought(pasts)
# and also memorize the generated thoughts, but not save
machine.learn(gen, pasts, 100)
generated_frames.append(gen)
artifact_path = os.path.dirname(os.path.abspath(__file__)) + "/artifacts/"
world.toGif(world.to_numpy(generated_frames, world_size), artifact_path + "sample_path.gif")