diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py new file mode 100644 index 00000000000..f0e9deef19c --- /dev/null +++ b/python/caffe/test/test_net.py @@ -0,0 +1,77 @@ +import unittest +import tempfile +import os +import numpy as np + +import caffe + +def simple_net_file(num_output): + """Make a simple net prototxt, based on test_net.cpp, returning the name + of the (temporary) file.""" + + f = tempfile.NamedTemporaryFile(delete=False) + f.write("""name: 'testnet' force_backward: true + layers { type: DUMMY_DATA name: 'data' top: 'data' top: 'label' + dummy_data_param { num: 5 channels: 2 height: 3 width: 4 + num: 5 channels: 1 height: 1 width: 1 + data_filler { type: 'gaussian' std: 1 } + data_filler { type: 'constant' } } } + layers { type: CONVOLUTION name: 'conv' bottom: 'data' top: 'conv' + convolution_param { num_output: 11 kernel_size: 2 pad: 3 + weight_filler { type: 'gaussian' std: 1 } + bias_filler { type: 'constant' value: 2 } } + weight_decay: 1 weight_decay: 0 } + layers { type: INNER_PRODUCT name: 'ip' bottom: 'conv' top: 'ip' + inner_product_param { num_output: """ + str(num_output) + """ + weight_filler { type: 'gaussian' std: 2.5 } + bias_filler { type: 'constant' value: -3 } } } + layers { type: SOFTMAX_LOSS name: 'loss' bottom: 'ip' bottom: 'label' + top: 'loss' }""") + f.close() + return f.name + +class TestNet(unittest.TestCase): + def setUp(self): + self.num_output = 13 + net_file = simple_net_file(self.num_output) + self.net = caffe.Net(net_file) + # fill in valid labels + self.net.blobs['label'].data[...] = \ + np.random.randint(self.num_output, + size=self.net.blobs['label'].data.shape) + os.remove(net_file) + + def test_memory(self): + """Check that holding onto blob data beyond the life of a Net is OK""" + + params = sum(map(list, self.net.params.itervalues()), []) + blobs = self.net.blobs.values() + del self.net + + # now sum everything (forcing all memory to be read) + total = 0 + for p in params: + total += p.data.sum() + p.diff.sum() + for bl in blobs: + total += bl.data.sum() + bl.diff.sum() + + def test_forward_backward(self): + self.net.forward() + self.net.backward() + + def test_inputs_outputs(self): + self.assertEqual(self.net.inputs, []) + self.assertEqual(self.net.outputs, ['loss']) + + def test_save_and_read(self): + f = tempfile.NamedTemporaryFile(delete=False) + f.close() + self.net.save(f.name) + net_file = simple_net_file(self.num_output) + net2 = caffe.Net(net_file, f.name) + os.remove(net_file) + os.remove(f.name) + for name in self.net.params: + for i in range(len(self.net.params[name])): + self.assertEqual(abs(self.net.params[name][i].data + - net2.params[name][i].data).sum(), 0) diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py new file mode 100644 index 00000000000..b78c91f9978 --- /dev/null +++ b/python/caffe/test/test_solver.py @@ -0,0 +1,49 @@ +import unittest +import tempfile +import os +import numpy as np + +import caffe +from test_net import simple_net_file + +class TestSolver(unittest.TestCase): + def setUp(self): + self.num_output = 13 + net_f = simple_net_file(self.num_output) + f = tempfile.NamedTemporaryFile(delete=False) + f.write("""net: '""" + net_f + """' + test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9 + weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75 + display: 100 max_iter: 100 snapshot_after_train: false""") + f.close() + self.solver = caffe.SGDSolver(f.name) + self.solver.net.set_mode_cpu() + # fill in valid labels + self.solver.net.blobs['label'].data[...] = \ + np.random.randint(self.num_output, + size=self.solver.net.blobs['label'].data.shape) + self.solver.test_nets[0].blobs['label'].data[...] = \ + np.random.randint(self.num_output, + size=self.solver.test_nets[0].blobs['label'].data.shape) + os.remove(f.name) + os.remove(net_f) + + def test_solve(self): + self.assertEqual(self.solver.iter, 0) + self.solver.solve() + self.assertEqual(self.solver.iter, 100) + + def test_net_memory(self): + """Check that nets survive after the solver is destroyed.""" + + nets = [self.solver.net] + list(self.solver.test_nets) + self.assertEqual(len(nets), 2) + del self.solver + + total = 0 + for net in nets: + for ps in net.params.itervalues(): + for p in ps: + total += p.data.sum() + p.diff.sum() + for bl in net.blobs.itervalues(): + total += bl.data.sum() + bl.diff.sum()