-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_balls.py
78 lines (61 loc) · 2.17 KB
/
test_balls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import numpy as np
import theano.tensor as T
import theano
from scipy import misc
from autoencoder import Autoencoder
from transform import *
from scene import *
from shader import *
from optimize import *
# Directory that receives the rendered reconstructions.
if not os.path.exists('output'):
    os.makedirs('output')
#train_data = np.array([misc.imread('example.png').flatten()], dtype='float32')/255.0
# One training example: '15.png' flattened into a single row, giving shape (1, D),
# with pixel values divided by 255 (presumably 8-bit pixels mapped to [0, 1] —
# TODO confirm the image bit depth).
train_data = np.asarray([misc.imread('15.png').flatten()], dtype='float32')/255.0
N, D = train_data.shape
# The scene camera is square (img_sz x img_sz), so the flattened image must
# contain exactly img_sz**2 values. Without this check a non-square or
# multi-channel image (e.g. RGB, where D = h*w*3) would silently yield a
# truncated, meaningless img_sz and fail much later in the graph.
img_sz = int(np.sqrt(D))
if img_sz * img_sz != D:
    raise ValueError('expected a square single-channel image, got %d pixel '
                     'values (not a perfect square)' % D)
def scene(capsules, obj_params):
    """Build and return the renderable scene for the given capsules.

    For each capsule, the row of ``obj_params`` with the same index supplies
    its transform, ``translate(obj_param[:3]) * scale(obj_param[3:])``. The
    capsule's ``name`` selects which object class wraps that transform. Every
    object shares one hard-coded material. The scene adds one fixed light and
    an ``img_sz`` x ``img_sz`` camera, and is rendered with
    ``DepthMapShader(6.1)``.

    Args:
        capsules: sequence of capsule objects; only their ``name`` attribute
            is read here ('sphere', 'square' or 'light').
        obj_params: indexed by capsule position; each entry is sliced as
            ``[:3]`` for translate() and ``[3:]`` for scale() (presumably a
            position followed by per-axis scales — TODO confirm against
            Autoencoder).

    Returns:
        The result of ``Scene(...).build()``.

    Raises:
        ValueError: if a capsule's ``name`` is not one of the names above.
            Previously such capsules were silently left out of the scene.
    """
    shapes = []
    #TODO move the material information to attribute of capsule instance
    material1 = Material((0.2, 0.9, 0.4), 0.3, 0.7, 0.5, 50.)
    for i, capsule in enumerate(capsules):
        # obj_params is indexed explicitly instead of being zip()-ed with
        # capsules: it may be a symbolic tensor that does not support
        # Python iteration.
        obj_param = obj_params[i]
        t1 = translate(obj_param[:3]) * scale(obj_param[3:])
        if capsule.name == 'sphere':
            shapes.append(Sphere(t1, material1))
        elif capsule.name == 'square':
            shapes.append(Square(t1, material1))
        elif capsule.name == 'light':
            # NOTE(review): the scene light below is built as
            # Light(position, color); confirm Light also accepts
            # (transform, material) as used here.
            shapes.append(Light(t1, material1))
        else:
            raise ValueError('unknown capsule name: %r' % (capsule.name,))
    light = Light((-1., -1., 2.), (0.961, 1., 0.87))
    camera = Camera(img_sz, img_sz)
    #shader = PhongShader()
    shader = DepthMapShader(6.1)
    # Bind to a name other than `scene` so the function is not shadowed.
    built_scene = Scene(shapes, [light], camera, shader)
    return built_scene.build()
# Hyper-parameters
num_capsule = 2
epsilon = 0.0001
num_epoch = 200

# Wire the scene renderer into the autoencoder, then compile the training
# step and two zero-argument helpers that evaluate on the single example.
ae = Autoencoder(scene, D, 300, 30, 10, num_capsule)
opt = MGDAutoOptimizer(ae)
train_ae = opt.optimize(train_data)
# Reconstruction of train_data[0], keeping only index 0 of the last axis.
get_recon = theano.function([], ae.get_reconstruct(train_data[0])[:, :, 0])
get_center = theano.function([], ae.encoder(train_data[0]))

# Record the untrained state: image, first encoded center, pixel total.
# NOTE(review): imsave is not imported by name in this file; presumably it
# arrives through one of the star imports — confirm.
recon = get_recon()
center = get_center()[0]
imsave('output/test_balls0.png', recon)
print('...Initial center1 (%g,%g,%g)' % (center[0], center[1], center[2]))
print(recon.sum())

# Epochs are numbered 1..num_epoch; the step size comes from get_epsilon.
for n in range(1, num_epoch + 1):
    eps = get_epsilon(epsilon, num_epoch, n)
    train_loss = train_ae(eps)
    center = get_center()[0]
    print('...Epoch %d Train loss %g, Center (%g, %g, %g)'
          % (n, train_loss, center[0], center[1], center[2]))
    # Every tenth epoch, write the current reconstruction to disk.
    if n % 10 == 0:
        image = get_recon()
        imsave('output/test_balls%d.png' % n, image)