This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

CuDNN support #58

Merged
4 commits, merged Sep 11, 2015
2 changes: 1 addition & 1 deletion Makefile
@@ -66,7 +66,7 @@ ifeq ($(USE_OPENMP_ITER), 1)
endif

ifeq ($(USE_CUDNN), 1)
CFLAGS += -DCXXNET_USE_CUDNN=1
CFLAGS += -DMSHADOW_USE_CUDNN=1
LDFLAGS += -lcudnn
endif

63 changes: 27 additions & 36 deletions example/cifar10/cifar10.py
@@ -5,7 +5,7 @@
import sys
sys.path.append("../../tests/python")
import get_data

import time

"""
CXXNET Result:
@@ -70,8 +70,8 @@ def ConvFactory(**kwargs):
param = copy.copy(kwargs)
act = param["act_type"]
del param["act_type"]
param["workspace"] = 256
param["name"] = "conv%d" % conv_cnt
param["nstep"] = 64
conv = mx.symbol.Convolution(**param)
bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt)
relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act)
@@ -89,13 +89,11 @@ def DownsampleFactory(data, ch_3x3, stride = 2):
param["num_filter"] = ch_3x3
param["act_type"] = "relu"
param["data"] = data
param["nstep"] = 100
param["pad"] = (1, 1)
conv3x3 = ConvFactory(**param)
# pool
del param["num_filter"]
del param["act_type"]
del param["nstep"]
del param["pad"]
param["pool_type"] = "max"
param["name"] = "pool%d" % pool_cnt
@@ -117,7 +115,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
param["stride"] = (1, 1)
param["act_type"] = "relu"
param["data"] = data
param["nstep"] = 128
conv1x1 = ConvFactory(**param)

# 3x3
@@ -143,7 +140,7 @@ def RandomInit(narray):
in3a = SimpleFactory(conv1, 32, 32)
in3b = SimpleFactory(in3a, 32, 48)
in3c = DownsampleFactory(in3b, 80)
in4a = SimpleFactory(in3c, 112, 38)
in4a = SimpleFactory(in3c, 112, 48)
in4b = SimpleFactory(in4a, 96, 64)
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
@@ -155,22 +152,26 @@ def RandomInit(narray):
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="sm")

args_list = loss.list_arguments()

epoch = 9
lr = 0.05
wd = 0.0001
momentum = 0.9

batch_size = 128
data_shape = (batch_size, 3, 28, 28)
arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape)

arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes]
in_data = mx.narray.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), {"data": in_data})
out_narray = executor.heads()[0]
pred = mx.narray.zeros(out_narray.shape, mx.cpu())

inputs = dict(zip(args_list, arg_narrays))
arg_narrays, grad_narrays = executor.list_arguments()
inputs = dict(zip(loss.list_arguments(), arg_narrays))
tmp_label = mx.narray.zeros(inputs["sm_label"].shape)
momentum_narrays = [mx.narray.zeros(item.shape, mx.gpu()) for item in grad_narrays]

name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.zeros(out_shapes[0])
block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

np.random.seed(0)
# set random weight
@@ -185,25 +186,11 @@ def RandomInit(narray):
if "beta" in name:
narray[:] = 0.0

# bind executer
# TODO(bing): think of a better bind interface
executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays)
# update

out_narray = executor.heads()[0]

epoch = 9
lr = 0.05
wd = 0.0001
momentum = 0.9

def Update(grad, weight, mom):
mom[:] *= momentum
mom[:] += -lr * (grad / batch_size + wd * weight)
weight[:] += mom

block = list(zip(grad_narrays, arg_narrays, mom_narrays))

#check data
get_data.GetCifar10()

@@ -224,17 +211,19 @@ def Update(grad, weight, mom):
batch_size=batch_size,
nthread=1)

tmp_label = mx.narray.zeros(name2shape["sm_label"])

def progress(count, total, suffix=''):
bar_len = 80
def progress(count, total, epoch, toc):
bar_len = 50
filled_len = int(round(bar_len * count / float(total)))

percents = round(100.0 * count / float(total), 1)
bar = '=' * filled_len + '-' * (bar_len - filled_len)

tic = time.time()
speed = batch_size / float(tic - toc)
suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))


def test_cifar():
acc_train = 0.
acc_val = 0.
@@ -245,9 +234,9 @@ def test_cifar():
val_acc = 0.0
train_nbatch = 0
val_nbatch = 0
all_train_bacth = 50000 / float(batch_size)
all_train_bacth = round(50000 / float(batch_size) + 1)
for data, label in train_dataiter:
progress(train_nbatch, all_train_bacth, "Epoch %d" % i)
toc = time.time()
label = label.asnumpy().flatten()
tmp_label[:] = label
inputs["data"][:] = data
@@ -256,10 +245,12 @@ def test_cifar():
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
train_nbatch += 1
executor.backward([out_narray])
#executor.backward([out_narray])
executor.backward()

for grad, weight, mom in block:
Update(grad, weight, mom)
progress(train_nbatch, all_train_bacth, i, toc)

# evaluate
for data, label in test_dataiter:
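For readers following the script changes above: the new binding flow replaces the manual infer_shape / bind / Update bookkeeping. Below is a minimal sketch of that flow under the API shown in this diff (simple_bind, Executor.list_arguments, argument-free backward). It assumes the `loss` symbol and the CIFAR-10 iterators defined in the script; `sgd_momentum_step` is an illustrative name, not part of the change.

import mxnet as mx

batch_size = 128
in_data = mx.narray.empty((batch_size, 3, 28, 28), mx.gpu())
executor = loss.simple_bind(mx.gpu(), {"data": in_data})   # shapes and contexts inferred from "data"
arg_narrays, grad_narrays = executor.list_arguments()       # new Executor helper, symbol order
inputs = dict(zip(loss.list_arguments(), arg_narrays))
momentum_narrays = [mx.narray.zeros(g.shape, mx.gpu()) for g in grad_narrays]

def sgd_momentum_step(grad, weight, mom, lr=0.05, wd=0.0001, mu=0.9):
    # same update rule as the script's Update()
    mom[:] *= mu
    mom[:] += -lr * (grad / batch_size + wd * weight)
    weight[:] += mom

for data, label in train_dataiter:
    inputs["data"][:] = data
    inputs["sm_label"][:] = label.asnumpy().flatten()
    executor.forward()
    executor.backward()          # no head gradient needed for the Softmax head
    for grad, weight, mom in zip(grad_narrays, arg_narrays, momentum_narrays):
        sgd_momentum_step(grad, weight, mom)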
2 changes: 1 addition & 1 deletion include/mxnet/base.h
@@ -30,7 +30,7 @@
*\brief whether to use cudnn library for convolution
*/
#ifndef MXNET_USE_CUDNN
#define MXNET_USE_CUDNN 0
#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
#endif

/*! \brief namespace of mxnet */
5 changes: 4 additions & 1 deletion include/mxnet/symbolic.h
@@ -394,9 +394,12 @@ class Executor {
* \brief Perform a Backward operation of the Operator.
* This must be called after Forward.
* After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
* The user is allowed to pass in an empty array if the head node is a
* loss function and the head gradient is not needed.
*
* \param head_grads the gradient of head nodes to be backproped.
*/
virtual void Backward(const std::vector<NArray> &head_grads) = 0;
virtual void Backward(const std::vector<NArray> &head_grads = {}) = 0;
/*!
* \brief get array of heads in the executor.
* \return array of heads in the executor.
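In Python, the practical effect of the defaulted argument is what the updated CIFAR-10 script relies on; a short sketch, assuming `executor` is bound to a symbol whose head is a loss such as Softmax (`head_grad` is a placeholder NArray, not part of this diff):

# Head is a loss node: the head gradient can be omitted entirely.
executor.backward()

# Head is an ordinary output: pass the head gradient(s) explicitly.
executor.backward([head_grad])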
2 changes: 1 addition & 1 deletion mshadow (submodule pointer update)
2 changes: 1 addition & 1 deletion python/mxnet/__init__.py
@@ -8,7 +8,7 @@
"""
from __future__ import absolute_import

from .context import Context, current_context
from .context import Context, current_context, cpu, gpu
from .base import MXNetError
from . import narray
from . import symbol
38 changes: 38 additions & 0 deletions python/mxnet/context.py
@@ -52,6 +52,44 @@ def __exit__(self, ptype, value, trace):
# initialize the default context in Context
Context.default_ctx = Context('cpu', 0)


def cpu(device_id=0):
"""Return a CPU context.

This function is a shortcut for Context('cpu', device_id)

Parameters
----------
device_id : int, optional
The device id of the device. device_id is not needed for CPU.
This is included to make interface compatible with GPU.

Returns
-------
context : Context
The corresponding CPU context.
"""
return Context('cpu', device_id)


def gpu(device_id=0):
"""Return a GPU context.

This function is a shortcut for Context('gpu', device_id)

Parameters
----------
device_id : int, optional
The device id of the device, needed for GPU

Returns
-------
context : Context
The corresponding GPU context.
"""
return Context('gpu', device_id)


def current_context():
"""Return the current context.

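Together with the __init__.py change above, the new helpers are importable straight from the package. A small usage sketch (the array shapes here are arbitrary):

import mxnet as mx

cpu0 = mx.cpu()      # same as mx.Context('cpu', 0)
gpu0 = mx.gpu(0)     # same as mx.Context('gpu', 0)

weights = mx.narray.zeros((64, 3, 3, 3), ctx=gpu0)
scratch = mx.narray.empty((128, 10), cpu0)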
45 changes: 39 additions & 6 deletions python/mxnet/executor.py
@@ -22,6 +22,34 @@ def __init__(self, handle):
if not isinstance(handle, ExecutorHandle):
raise TypeError("Handle type error")
self.handle = handle
self.arg_narrays = []
self.grad_narrays = []
self.auxiliary_states = []

def list_arguments(self, with_grad=True):
"""Return arguments (and grad for arguments)

Parameters
----------
with_grad: bool
whether to return the gradient arrays along with the argument arrays

Returns
-------
If with_grad is True, return an (args, grads) pair of lists;
otherwise return only the args list.
Note: the argument order is the same as symbol.list_arguments()
"""
if with_grad:
return self.arg_narrays, self.grad_narrays
else:
return self.arg_narrays

def list_auxiliary_states(self):
"""Return auxiliary states of executor
Note: auxiliary states is same to symbol.list_auxiliary_states()
"""
return self.auxiliary_states

def forward(self, is_train=True):
"""Do forward.
@@ -34,19 +62,24 @@
"""
check_call(_LIB.MXExecutorForward(self.handle, is_train))

def backward(self, grads):
def backward(self, head_grads=None):
"""Do backward on heads' gradient.

Parameters
----------
grads: Array of NArray
heads' gradient
head_grads : NArray or list of NArray, optional
Gradient on the heads
"""
for obj in grads:
if head_grads is None:
head_grads = []
elif isinstance(head_grads, NArray):
head_grads = [head_grads]

for obj in head_grads:
if not isinstance(obj, NArray):
raise TypeError("inputs must be NArray")
narray = c_array(NArrayHandle, [item.handle for item in grads])
check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray))
narray = c_array(NArrayHandle, [item.handle for item in head_grads])
check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), narray))

def heads(self):
"""list all heads' output narray
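A brief sketch of the new Executor helpers added in this file, assuming `executor` was created with simple_bind as in the CIFAR-10 example (the names on the left-hand side are illustrative):

args, grads = executor.list_arguments()          # paired, in symbol.list_arguments() order
args_only = executor.list_arguments(with_grad=False)
aux_states = executor.list_auxiliary_states()    # e.g. BatchNorm moving statistics

executor.forward(is_train=True)
executor.backward()                # None -> empty head-gradient list
# executor.backward(head_grad)     # a single NArray is wrapped into a list
# executor.backward([g1, g2])      # or pass a list for multiple heads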
10 changes: 2 additions & 8 deletions python/mxnet/narray.py
@@ -349,9 +349,7 @@ def zeros(shape, ctx=None):
out: Array
The created NArray.
"""
if ctx is None:
ctx = Context.default_ctx
arr = NArray(handle=_new_alloc_handle(shape, ctx, False))
arr = empty(shape, ctx)
arr[:] = 0.0
return arr

@@ -371,15 +369,11 @@ def ones(shape, ctx=None):
out: Array
The created NArray.
"""
if ctx is None:
ctx = Context.default_ctx
arr = NArray(handle=_new_alloc_handle(shape, ctx, False))
arr = empty(shape, ctx)
arr[:] = 1.0
return arr




def array(source_array, ctx=None):
"""Create a new NArray that copies content from source_array.

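The zeros/ones simplification above just delegates allocation to empty() and then fills the array. A hedged sketch of equivalent calls, assuming empty() falls back to the default context when ctx is None, as the removed code did:

import mxnet as mx

a = mx.narray.zeros((4, 4))                  # default context (cpu 0)
b = mx.narray.ones((4, 4), ctx=mx.gpu())     # explicit context

# zeros(shape, ctx) is now effectively:
c = mx.narray.empty((4, 4), mx.gpu())
c[:] = 0.0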