
Merge pull request #34 from antinucleon/master
Minor change, add MNIST to test
antinucleon committed Aug 26, 2015
2 parents 1cc0ccb + d576bb3 commit 2ad67a3
Showing 7 changed files with 231 additions and 61 deletions.
2 changes: 2 additions & 0 deletions include/mxnet/base.h
@@ -55,6 +55,8 @@ typedef mshadow::TBlob TBlob;
namespace dmlc {
// Add a few patches to support TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(uint32_t, "unsigned int");


namespace parameter {
template<>
26 changes: 15 additions & 11 deletions src/operator/convolution-inl.h
@@ -101,7 +101,7 @@ class ConvolutionOp : public Operator {
// TODO(bing): make mshadow support dual stride
}
const index_t gstride = temp_col_.size(0) / param_.nb_group;
for (int gid = 0; gid < param_.nb_group; ++gid) {
for (uint32_t gid = 0; gid < param_.nb_group; ++gid) {
mshadow::Tensor<xpu, 2> tmpc = temp_col_.Slice(gstride * gid,
gstride * (gid + 1));
temp_dst_[gid] = dot(wmat[gid], tmpc);
@@ -148,9 +148,9 @@ class ConvolutionOp : public Operator {
const index_t nbatch = data.size(0);
for (index_t i = 0; i < nbatch; i += param_.nstep) {
const index_t step = std::min(param_.nstep, nbatch - i);
temp_col_.Resize(mshadow::Shape2(shape_colunit_[0],
temp_col_.Resize(Shape2(shape_colunit_[0],
shape_colunit_[1] * step));
temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0],
temp_dst_.Resize(Shape3(shape_dstunit_[0],
shape_dstunit_[1], shape_dstunit_[2] * step));
temp_dst_ = reshape(swapaxis<1, 0>(grad.Slice(i, i + step)), temp_dst_.shape_);
if (param_.pad[0] == 0 && param_.pad[1] == 0) {
@@ -167,13 +167,18 @@ class ConvolutionOp : public Operator {
param_.stride[0]);
}
const index_t gstride = temp_col_.size(0) / param_.nb_group;
for (int gid = 0; gid < param_.nb_group; ++gid) {
mshadow::Tensor<xpu, 2> tmpc = temp_col_.Slice(gstride * gid, gstride * (gid + 1));
gwmat[gid] += dot(temp_dst_[gid], tmpc.T());
for (uint32_t gid = 0; gid < param_.nb_group; ++gid) {
Tensor<xpu, 2> tmpc = temp_col_.Slice(gstride * gid, gstride * (gid + 1));
if (i == 0) {
Tensor<xpu, 2> tmp_gwmat = gwmat[gid];
Assign(tmp_gwmat, req[kWeight], dot(temp_dst_[gid], tmpc.T()));
} else {
gwmat[gid] += dot(temp_dst_[gid], tmpc.T());
}
}
if (req[kData] == kWriteTo) {
for (int gid = 0; gid < param_.nb_group; ++gid) {
mshadow::Tensor<xpu, 2> tmpc = temp_col_.Slice(gstride * gid, gstride * (gid + 1));
for (uint32_t gid = 0; gid < param_.nb_group; ++gid) {
Tensor<xpu, 2> tmpc = temp_col_.Slice(gstride * gid, gstride * (gid + 1));
tmpc = dot(wmat[gid].T(), temp_dst_[gid]);
}
if (param_.pad[0] == 0 && param_.pad[1] == 0) {
@@ -183,7 +188,7 @@ class ConvolutionOp : public Operator {
param_.kernel[1],
param_.stride[0]);
} else {
mshadow::Shape<4> pshape = data.Slice(i, i + step).shape_;
Shape<4> pshape = data.Slice(i, i + step).shape_;
pshape[2] += 2 * param_.pad[0];
pshape[3] += 2 * param_.pad[1];
gdata.Slice(i, i + step) = crop(pack_col2patch(temp_col_,
@@ -197,8 +202,7 @@ class ConvolutionOp : public Operator {
}
if (!param_.no_bias) {
Tensor<xpu, 1> gbias = in_grad[kBias].get<xpu, 1, real_t>(s);
// Assign(gbias, req[kBias], sumall_except_dim<1>(grad);
gbias += sumall_except_dim<1>(grad);
Assign(gbias, req[kBias], sumall_except_dim<1>(grad));
}
}

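The backward pass above walks the batch in nstep-sized chunks, so the weight gradient of each convolution group is assembled from several partial dot products. The new `if (i == 0)` branch applies `Assign(..., req[kWeight], ...)` only on the first chunk, so a kWriteTo or kAddTo request is honoured exactly once, and every later chunk simply accumulates; the bias gradient gets the same treatment, with `Assign(gbias, req[kBias], ...)` replacing the unconditional `+=`. A minimal numpy sketch of that accumulation pattern (the chunk gradients, the request type, and the 4x4 shape are illustrative assumptions, not values from this diff):

```python
# Hedged sketch: honour a write/add request on the first chunk, accumulate after.
import numpy as np

def accumulate_grad(chunks, req="write", gwmat=None):
    """chunks: one partial weight gradient per nstep chunk of the batch."""
    for i, g in enumerate(chunks):
        if i == 0 and req == "write":
            gwmat = g.copy()          # first chunk: overwrite previous contents
        else:
            gwmat = gwmat + g         # 'add' request, or any later chunk: accumulate
    return gwmat

chunks = [np.ones((4, 4)) for _ in range(3)]      # three nstep chunks
print(accumulate_grad(chunks, req="write"))       # every entry == 3.0
```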
15 changes: 9 additions & 6 deletions src/operator/pooling-inl.h
@@ -66,7 +66,6 @@ class PoolingOp : public Operator {
const std::vector<TBlob> &out_data) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(req[kOut], kWriteTo);
CHECK_EQ(in_data.size(), 1);
CHECK_EQ(out_data.size(), 1);
Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -75,18 +74,22 @@ class PoolingOp : public Operator {
mshadow::Shape<2> out_shape = Shape2(out.shape_[2], out.shape_[3]);
// TODO(bing): dual stride in mshadow
if (param_.pool_type == kMaxPooling || param_.pool_type == kSumPooling) {
out = pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
Assign(out,
req[kOut],
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.kernel[0]);
param_.kernel[0]));
} else if (param_.pool_type == kAvgPooling) {
out = (1.0f / (param_.kernel[0] * param_.kernel[1])) * \
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
Assign(out,
req[kOut],
(1.0f / (param_.kernel[0] * param_.kernel[1])) * \
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.kernel[0]);
param_.kernel[0]));
}
}

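Both pooling branches now write through `Assign(out, req[kOut], ...)` instead of assuming a plain write (the `CHECK_EQ(req[kOut], kWriteTo)` is dropped), and average pooling remains a `pool<Reducer>` over the padded input scaled by `1.0f / (kernel[0] * kernel[1])`. A small numpy sketch of that "sum-pool, then scale by the kernel area" idea; the 2x2 kernel, stride 2, and zero padding are illustrative assumptions rather than values taken from this diff:

```python
# Hedged sketch of average pooling as a sum over each window scaled by 1/(kh*kw).
import numpy as np

def avg_pool2d(x, kh=2, kw=2, sh=2, sw=2):
    h, w = x.shape
    oh, ow = (h - kh) // sh + 1, (w - kw) // sw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            window = x[i * sh:i * sh + kh, j * sw:j * sw + kw]
            out[i, j] = window.sum() * (1.0 / (kh * kw))   # sum-pool, then scale
    return out

print(avg_pool2d(np.arange(16.0).reshape(4, 4)))   # [[ 2.5  4.5] [10.5 12.5]]
```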
1 change: 1 addition & 0 deletions src/operator/reshape.cc
@@ -27,6 +27,7 @@ MXNET_REGISTER_OP_PROPERTY(Reshape, ReshapeProp)

MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp)
.add_argument("data", "Symbol", "Input data to flatten.")
.add_arguments(ReshapeParam::__FIELDS__())
.describe("Flatten input");
} // namespace op
} // namespace mxnet
2 changes: 1 addition & 1 deletion src/operator/softmax-inl.h
@@ -47,7 +47,7 @@ class SoftmaxOp : public Operator {
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 2> data = in_data[kData].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> out = out_data[kOut].FlatTo2D<xpu, real_t>(s);
Softmax(data, out);
Softmax(out, data);
}

virtual void Backward(const OpContext &ctx,
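The one-line fix swaps the argument order to `Softmax(out, data)`, which is consistent with mshadow's destination-first convention, so the result is written into `out_data[kOut]` rather than back into the input. For reference, a hedged numpy sketch of the row-wise softmax such a call produces (the max-subtraction for numerical stability is an assumption about the implementation, not something shown in this diff):

```python
# Hedged sketch: row-wise softmax over a (batch, classes) matrix.
import numpy as np

def softmax_rows(data):
    shifted = data - data.max(axis=1, keepdims=True)   # stabilise before exp
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

print(softmax_rows(np.array([[1.0, 2.0, 3.0]])))       # each row sums to 1.0
```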
104 changes: 61 additions & 43 deletions python/test_mnist.py → tests/python/test_conv.py
@@ -1,20 +1,29 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import os, cPickle, gzip
import os, pickle, gzip
import sys


def CalAcc(out, label):
pred = np.argmax(out, axis=1)
return np.sum(pred == label) * 1.0 / out.shape[0]

def IgnorePython3():
if sys.version_info[0] >= 3:
# TODO(tianjun): use IO instead of pickle
# Python3 pickle is not able to load data correctly
sys.exit(0)


# load data
class MNISTIter(object):
def __init__(self, which_set, batch_size=100, flatten=True):
if not os.path.exists('mnist.pkl.gz'):
os.system("wget http://deeplearning.net/data/mnist/mnist.pkl.gz")
f = gzip.open('mnist.pkl.gz', 'rb')
train_set, valid_set, test_set = cPickle.load(f)
IgnorePython3()
train_set, valid_set, test_set = pickle.load(f)
f.close()
if which_set == 'train':
self.data = train_set[0]
@@ -55,10 +64,16 @@ def Get(self):
# symbol net
batch_size = 100
data = mx.symbol.Variable('data')
fc1 = mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(7,7), stride=(2,2), nstep=10, no_bias=1)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
mp = mx.symbol.Pooling(data = act1, name = 'mp', kernel=(2,2), stride=(2,2), pool_type='avg')
fl = mx.symbol.Flatten(data = mp, name="flatten")
conv1= mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(3,3), stride=(1,1), nstep=10)
act1 = mx.symbol.Activation(data = conv1, name='relu1', act_type="relu")
mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')

conv2= mx.symbol.Convolution(data = mp1, name='conv2', nb_filter=32, kernel=(3,3), stride=(1,1), nstep=10)
act2 = mx.symbol.Activation(data = conv2, name='relu2', act_type="relu")
mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')


fl = mx.symbol.Flatten(data = mp2, name="flatten")
fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
softmax = mx.symbol.Softmax(data = fc2, name = 'sm')
args_list = softmax.list_arguments()
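The rewritten test stacks two conv → relu → max-pool blocks before flattening into a 10-way fully connected layer and softmax. A hedged sketch of how the inferred shapes of this graph could be inspected, continuing from the `softmax` symbol defined just above and reusing the same `infer_shape` / `list_arguments` calls the script itself makes; the `(100, 1, 28, 28)` MNIST batch shape is an assumption, since `data_shape` is defined outside the lines shown in this hunk:

```python
# Hedged sketch: print the shapes mxnet infers for every argument of the graph.
arg_shapes, out_shapes = softmax.infer_shape(data=(100, 1, 28, 28))
for name, shape in zip(softmax.list_arguments(), arg_shapes):
    print("%s %s" % (name, str(shape)))        # conv1/conv2/fc2 weights and biases, data, sm_label
print("softmax output: %s" % str(out_shapes[0]))   # expected (100, 10): one score row per image
```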
@@ -69,14 +84,12 @@ def Get(self):
arg_shapes, out_shapes = softmax.infer_shape(data=data_shape)
arg_narrays = [mx.narray.create(shape) for shape in arg_shapes]
grad_narrays = [mx.narray.create(shape) for shape in arg_shapes]
mom_narrays = [mx.narray.create(shape) for shape in arg_shapes]
inputs = dict(zip(args_list, arg_narrays))
print zip(args_list, arg_shapes)
np.random.seed(0)
# set random weight
for name, narray in inputs.items():
if "weight" in name:
narray.numpy[:, :] = np.random.uniform(-0.001, 0.001, narray.numpy.shape)
narray.numpy[:, :] = np.random.uniform(-0.07, 0.07, narray.numpy.shape)
if "bias" in name:
narray.numpy[:] = 0.0

@@ -89,47 +102,52 @@ def Get(self):
out_narray = executor.heads()[0]
grad_narray = mx.narray.create(out_narray.shape)

epoch = 10
epoch = 1
momentum = 0.9
lr = 0.001
lr = 0.1
wd = 0.0004

def Update(mom, grad, weight):
weight.numpy[:] -= lr * grad.numpy[:]
def Update(grad, weight):
weight.numpy[:] -= lr * grad.numpy[:] / batch_size

block = zip(mom_narrays, grad_narrays, arg_narrays)
block = zip(grad_narrays, arg_narrays)


train = MNISTIter("train", batch_size, False)
valid = MNISTIter("valid", batch_size, False)

for i in xrange(epoch):
# train
print "Epoch %d" % i
train_acc = 0.0
val_acc = 0.0
while train.Next():
data, label = train.Get()
inputs["data"].numpy[:] = data
inputs["sm_label"].numpy[:] = label
executor.forward()
train_acc += CalAcc(out_narray.numpy, label)
grad_narray.numpy[:] = out_narray.numpy
executor.backward([grad_narray])

for mom, grad, weight in block:
Update(mom, grad, weight)

# evaluate
while valid.Next():
data, label = valid.Get()
inputs["data"].numpy[:] = data
executor.forward()
val_acc += CalAcc(out_narray.numpy, label)
print "Train Acc: ", train_acc / train.nbatch
print "Valid Acc: ", val_acc / valid.nbatch
train.BeforeFirst()
valid.BeforeFirst()


def test_mnist():
acc_train = 0.0
acc_val = 0.0
for i in xrange(epoch):
# train
print("Epoch %d" % i)
train_acc = 0.0
val_acc = 0.0
while train.Next():
data, label = train.Get()
inputs["data"].numpy[:] = data
inputs["sm_label"].numpy[:] = label
executor.forward()
train_acc += CalAcc(out_narray.numpy, label)
grad_narray.numpy[:] = out_narray.numpy
executor.backward([grad_narray])

for grad, weight in block:
Update(grad, weight)

# evaluate
while valid.Next():
data, label = valid.Get()
inputs["data"].numpy[:] = data
executor.forward()
val_acc += CalAcc(out_narray.numpy, label)
print("Train Acc: ", train_acc / train.nbatch)
print("Valid Acc: ", val_acc / valid.nbatch)
acc_train = train_acc / train.nbatch
acc_val = val_acc / valid.nbatch
train.BeforeFirst()
valid.BeforeFirst()
assert(acc_train > 0.84)
assert(acc_val > 0.96)
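The training logic is wrapped into `test_mnist()` with final accuracy assertions, the momentum buffers are gone, and `Update` becomes plain SGD with the gradient scaled by `1 / batch_size`, so the step size no longer grows with the batch. A tiny numpy sketch of that update; the learning rate, batch size, and tensor values are illustrative only:

```python
# Hedged sketch of the Update rule: weight -= lr * grad / batch_size.
import numpy as np

lr, batch_size = 0.1, 100
weight = np.zeros(5)
grad = np.full(5, 200.0)             # gradient accumulated over one mini-batch
weight -= lr * grad / batch_size     # scale by 1/batch_size, then take an SGD step
print(weight)                        # [-0.2 -0.2 -0.2 -0.2 -0.2]
```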
