diff --git a/Makefile b/Makefile
index 880ea7a73ab8..e5fc24acc97a 100644
--- a/Makefile
+++ b/Makefile
@@ -66,7 +66,7 @@ ifeq ($(USE_OPENMP_ITER), 1)
 endif

 ifeq ($(USE_CUDNN), 1)
-	CFLAGS += -DCXXNET_USE_CUDNN=1
+	CFLAGS += -DMSHADOW_USE_CUDNN=1
 	LDFLAGS += -lcudnn
 endif
diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py
index 10a7c40eea03..892f6fff5d1d 100644
--- a/example/cifar10/cifar10.py
+++ b/example/cifar10/cifar10.py
@@ -5,7 +5,7 @@
 import sys
 sys.path.append("../../tests/python")
 import get_data
-
+import time
 """
 CXXNET Result:
@@ -70,8 +70,8 @@ def ConvFactory(**kwargs):
     param = copy.copy(kwargs)
     act = param["act_type"]
     del param["act_type"]
+    param["workspace"] = 256
     param["name"] = "conv%d" % conv_cnt
-    param["nstep"] = 64
     conv = mx.symbol.Convolution(**param)
     bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt)
     relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act)
@@ -89,13 +89,11 @@ def DownsampleFactory(data, ch_3x3, stride = 2):
     param["num_filter"] = ch_3x3
     param["act_type"] = "relu"
     param["data"] = data
-    param["nstep"] = 100
     param["pad"] = (1, 1)
     conv3x3 = ConvFactory(**param)
     # pool
     del param["num_filter"]
     del param["act_type"]
-    del param["nstep"]
     del param["pad"]
     param["pool_type"] = "max"
     param["name"] = "pool%d" % pool_cnt
@@ -117,7 +115,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
     param["stride"] = (1, 1)
     param["act_type"] = "relu"
     param["data"] = data
-    param["nstep"] = 128
     conv1x1 = ConvFactory(**param)

     # 3x3
@@ -143,7 +140,7 @@ def RandomInit(narray):
 in3a = SimpleFactory(conv1, 32, 32)
 in3b = SimpleFactory(in3a, 32, 48)
 in3c = DownsampleFactory(in3b, 80)
-in4a = SimpleFactory(in3c, 112, 38)
+in4a = SimpleFactory(in3c, 112, 48)
 in4b = SimpleFactory(in4a, 96, 64)
 in4c = SimpleFactory(in4b, 80, 80)
 in4d = SimpleFactory(in4c, 48, 96)
@@ -155,22 +152,26 @@ def RandomInit(narray):
 fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
 loss = mx.symbol.Softmax(data=fc, name="sm")

-args_list = loss.list_arguments()
+epoch = 9
+lr = 0.05
+wd = 0.0001
+momentum = 0.9

 batch_size = 128
 data_shape = (batch_size, 3, 28, 28)
-arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape)
-arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
-grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
-mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
-aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes]
+in_data = mx.narray.empty(data_shape, mx.gpu())
+executor = loss.simple_bind(mx.gpu(), {"data": in_data})
+out_narray = executor.heads()[0]
+pred = mx.narray.zeros(out_narray.shape, mx.cpu())

-inputs = dict(zip(args_list, arg_narrays))
+arg_narrays, grad_narrays = executor.list_arguments()
+inputs = dict(zip(loss.list_arguments(), arg_narrays))
+tmp_label = mx.narray.zeros(inputs["sm_label"].shape)
+momentum_narrays = [mx.narray.zeros(item.shape, mx.gpu()) for item in grad_narrays]

-name2shape = dict(zip(args_list, arg_shapes))
-pred = mx.narray.zeros(out_shapes[0])
+block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

 np.random.seed(0)
 # set random weight
@@ -185,25 +186,11 @@ def RandomInit(narray):
     if "beta" in name:
         narray[:] = 0.0

-# bind executer
-# TODO(bing): think of a better bind interface
-executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays)
-# update
-
-out_narray = executor.heads()[0]
-
-epoch = 9
-lr = 0.05
-wd = 0.0001
-momentum = 0.9
-
 def Update(grad, weight, mom):
     mom[:] *= momentum
     mom[:] += -lr * (grad / batch_size + wd * weight)
     weight[:] += mom

-block = list(zip(grad_narrays, arg_narrays, mom_narrays))
-
 #check data
 get_data.GetCifar10()

@@ -224,17 +211,19 @@ def Update(grad, weight, mom):
                                  batch_size=batch_size,
                                  nthread=1)

-tmp_label = mx.narray.zeros(name2shape["sm_label"])
-def progress(count, total, suffix=''):
-    bar_len = 80
+def progress(count, total, epoch, toc):
+    bar_len = 50
     filled_len = int(round(bar_len * count / float(total)))
     percents = round(100.0 * count / float(total), 1)
     bar = '=' * filled_len + '-' * (bar_len - filled_len)
-
+    tic = time.time()
+    speed = batch_size / float(tic - toc)
+    suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
     sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))

+
 def test_cifar():
     acc_train = 0.
     acc_val = 0.
@@ -245,9 +234,9 @@ def test_cifar():
     val_acc = 0.0
     train_nbatch = 0
     val_nbatch = 0
-    all_train_bacth = 50000 / float(batch_size)
+    all_train_bacth = round(50000 / float(batch_size) + 1)
     for data, label in train_dataiter:
-        progress(train_nbatch, all_train_bacth, "Epoch %d" % i)
+        toc = time.time()
         label = label.asnumpy().flatten()
         tmp_label[:] = label
         inputs["data"][:] = data
@@ -256,10 +245,12 @@ def test_cifar():
         pred[:] = out_narray
         train_acc += CalAcc(pred.asnumpy(), label)
         train_nbatch += 1
-        executor.backward([out_narray])
+        #executor.backward([out_narray])
+        executor.backward()

         for grad, weight, mom in block:
             Update(grad, weight, mom)
+        progress(train_nbatch, all_train_bacth, i, toc)

     # evaluate
     for data, label in test_dataiter:
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index a7a3a8063a92..e3fbe002fdfc 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -30,7 +30,7 @@
  *\brief whether to use cudnn library for convolution
  */
 #ifndef MXNET_USE_CUDNN
-#define MXNET_USE_CUDNN 0
+#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
 #endif

 /*! \brief namespace of mxnet */
diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h
index 97eed74a53be..ef9e562f64bc 100644
--- a/include/mxnet/symbolic.h
+++ b/include/mxnet/symbolic.h
@@ -394,9 +394,12 @@ class Executor {
    * \brief Perform a Backward operation of the Operator.
    * This must be called after Forward.
    * After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
+   * User is allowed to pass in an empty Array if the head node is
+   * a loss function and the head gradient is not needed.
+   *
    * \param head_grads the gradient of head nodes to be backproped.
    */
-  virtual void Backward(const std::vector<NArray> &head_grads) = 0;
+  virtual void Backward(const std::vector<NArray> &head_grads = {}) = 0;
   /*!
    * \brief get array of heads in the executor.
    * \return array of heads in the executor.
diff --git a/mshadow b/mshadow
index 3053f8cdfea0..208a198213ea 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit 3053f8cdfea0274739282ced015ad458090760e8
+Subproject commit 208a198213ea011e42f91b128b14a7206cce62a5
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 78b986e9c6e0..7c4246f25285 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -8,7 +8,7 @@
 """
 from __future__ import absolute_import

-from .context import Context, current_context
+from .context import Context, current_context, cpu, gpu
 from .base import MXNetError
 from . import narray
 from . import symbol
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index 7a043bf5c4b9..707591f9fc19 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -52,6 +52,44 @@ def __exit__(self, ptype, value, trace):
 # initialize the default context in Context
 Context.default_ctx = Context('cpu', 0)

+
+def cpu(device_id=0):
+    """Return a CPU context.
+
+    This function is a shortcut for Context('cpu', device_id).
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the device. device_id is not needed for CPU.
+        This is included to make the interface compatible with GPU.
+
+    Returns
+    -------
+    context : Context
+        The corresponding CPU context.
+    """
+    return Context('cpu', device_id)
+
+
+def gpu(device_id=0):
+    """Return a GPU context.
+
+    This function is a shortcut for Context('gpu', device_id).
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the device, needed for GPU.
+
+    Returns
+    -------
+    context : Context
+        The corresponding GPU context.
+    """
+    return Context('gpu', device_id)
+
+
 def current_context():
     """Return the current context.
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 17df30190ce8..a3ba09ca1a76 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -22,6 +22,34 @@ def __init__(self, handle):
         if not isinstance(handle, ExecutorHandle):
             raise TypeError("Handle type error")
         self.handle = handle
+        self.arg_narrays = []
+        self.grad_narrays = []
+        self.auxiliary_states = []
+
+    def list_arguments(self, with_grad=True):
+        """Return the argument NArrays (and the gradients of the arguments).
+
+        Parameters
+        ----------
+        with_grad : bool
+            Whether to return the gradient NArrays along with the arguments.
+
+        Returns
+        -------
+        If with_grad is True, the (args, grads) pair of lists;
+        otherwise the args list only.
+        Note: the order is the same as symbol.list_arguments().
+        """
+        if with_grad:
+            return self.arg_narrays, self.grad_narrays
+        else:
+            return self.arg_narrays
+
+    def list_auxiliary_states(self):
+        """Return the auxiliary states of the executor.
+
+        Note: the order is the same as symbol.list_auxiliary_states().
+        """
+        return self.auxiliary_states

     def forward(self, is_train=True):
         """Do forward.
@@ -34,19 +62,24 @@ def forward(self, is_train=True):
         """
         check_call(_LIB.MXExecutorForward(self.handle, is_train))

-    def backward(self, grads):
+    def backward(self, head_grads=None):
         """Do backward on heads' gradient.

         Parameters
         ----------
-        grads: Array of NArray
-            heads' gradient
+        head_grads : NArray or list of NArray, optional
+            Gradient on the heads.
         """
-        for obj in grads:
+        if head_grads is None:
+            head_grads = []
+        elif isinstance(head_grads, NArray):
+            head_grads = [head_grads]
+
+        for obj in head_grads:
             if not isinstance(obj, NArray):
                 raise TypeError("inputs must be NArray")
-        narray = c_array(NArrayHandle, [item.handle for item in grads])
-        check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray))
+        narray = c_array(NArrayHandle, [item.handle for item in head_grads])
+        check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), narray))

     def heads(self):
         """list all heads' output narray
diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py
index acc05d08d546..208fd8e17d7a 100644
--- a/python/mxnet/narray.py
+++ b/python/mxnet/narray.py
@@ -349,9 +349,7 @@ def zeros(shape, ctx=None):
     out: Array
         The created NArray.
""" - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 0.0 return arr @@ -371,15 +369,11 @@ def ones(shape, ctx=None): out: Array The created NArray. """ - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 1.0 return arr - - def array(source_array, ctx=None): """Create a new NArray that copies content from source_array. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index f882933538b2..90df0b663615 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -10,7 +10,7 @@ from .base import NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context -from .narray import NArray +from .narray import NArray, zeros from .executor import Executor @@ -332,6 +332,45 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): raise TypeError('Only Accept list of NArrays or dict of str->NArray') return c_array(NArrayHandle, arg_handles) + def simple_bind(self, ctx, args, grad_req='write'): + """Simply bind current symbol to get an executor + Parameters + ---------- + ctx : Context + The device context the generated executor to run on. + + args : list of NArray or dict of str->NArray + Input arguments to the symbol. + - type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray, + - Not all the arguments must be provided. + Returns + ------- + executor : mxnet.Executor + The generated Executor + """ + if not isinstance(args, dict): + raise TypeError("args must be dict of str->NArray") + input_shapes = dict((name, arr.shape) for name, arr in args.items()) + # pylint: disable=unused-variable + arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) + # pylint: enable=unused-variable + if arg_shapes == None: + raise ValueError("Input node is not complete") + # alloc space + arg_narrays = [] + for name, shape in zip(self.list_arguments(), arg_shapes): + if name in args: + arg_narrays.append(args[name]) + else: + arg_narrays.append(zeros(shape, ctx)) + # TODO(bing): specail treat input data grad + # TODO(bing): not generate grad case + grad_narrays = [zeros(shape, ctx) for shape in arg_shapes] + aux_narrays = [zeros(shape, ctx) for shape in aux_shapes] + executor = self.bind(ctx, arg_narrays, grad_narrays, grad_req, aux_narrays) + return executor + def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): """Bind current symbol to get an executor. @@ -386,6 +425,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): User can give up gradient by using a dict in args_grad and only specify gradient they interested in. """ + # pylint: disable=too-many-locals if not isinstance(ctx, Context): raise TypeError("Context type error") @@ -430,7 +470,11 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): len(aux_states), aux_args_handle, ctypes.byref(handle))) - return Executor(handle) + executor = Executor(handle) + executor.arg_narrays = args + executor.grad_narrays = args_grad + executor.auxiliary_states = aux_states + return executor def grad(self, wrt): """Get the autodiff of current symbol. 
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index ff3b0467b08e..7f38fd92de35 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -10,16 +10,21 @@ namespace engine {
 NaiveEngine::VarHandle NaiveEngine::NewVariable() { return nullptr; }

 NaiveEngine::NaiveEngine() {
-#if MXNET_USE_CUDA
+  #if MXNET_USE_CUDA
+  #if MXNET_USE_CUDNN == 1
+  LOG(INFO) << "MXNet is using CuDNN for Convolution and Pooling operators";
+  stream_ = mshadow::NewStream<gpu>(true, true);
+  #else
   stream_ = mshadow::NewStream<gpu>(true, false);
+  #endif  // MXNET_USE_CUDNN
   ctx_.stream = stream_;
-#endif
+  #endif  // MXNET_USE_CUDA
 }

 NaiveEngine::~NaiveEngine() {
-#if MXNET_USE_CUDA
+  #if MXNET_USE_CUDA
   mshadow::DeleteStream(stream_);
-#endif
+  #endif
 }

 NaiveEngine::OprHandle NaiveEngine::NewOperator(AsyncFn,
@@ -66,5 +71,5 @@ void NaiveEngine::WaitForVar(VarHandle) {}
 void NaiveEngine::WaitForAll() {}

 }  // namespace engine
-
 }  // namespace mxnet
+
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 75bca74935bc..76f18d29ec05 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -56,7 +56,11 @@ RunContext StreamManager::GetRunContext(
       auto&& counter = gpu_cnt_.at(ctx.dev_id);
       if (counter == -1) {
         for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
+          #if MXNET_USE_CUDNN == 1
+          i = mshadow::NewStream<gpu>(true, true);
+          #else
           i = mshadow::NewStream<gpu>(true, false);
+          #endif  // MXNET_USE_CUDNN
         }
         counter = 0;
       }
diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h
index 43aa4f01637a..2319f074cc73 100644
--- a/src/operator/activation-inl.h
+++ b/src/operator/activation-inl.h
@@ -117,7 +117,11 @@ class ActivationProp : public OperatorProperty {
       const std::vector<int> &out_grad,
       const std::vector<int> &in_data,
       const std::vector<int> &out_data) const override {
+    #if MXNET_USE_CUDNN == 1
+    return {out_grad[kOut], out_data[kOut], in_data[kData]};
+    #else
     return {out_grad[kOut], out_data[kOut]};
+    #endif  // MXNET_USE_CUDNN
   }

   std::vector<std::pair<int, void*> > BackwardInplaceOption(
diff --git a/src/operator/activation.cu b/src/operator/activation.cu
index b1b8fc4fb8b0..4325d4c53a46 100644
--- a/src/operator/activation.cu
+++ b/src/operator/activation.cu
@@ -6,11 +6,17 @@
 */
 #include "./activation-inl.h"
 #include "./mshadow_op.h"
+#if MXNET_USE_CUDNN == 1
+#include "./cudnn_activation-inl.h"
+#endif

 namespace mxnet {
 namespace op {
 template<>
 Operator *CreateOp<gpu>(ActivationParam param) {
+  #if MXNET_USE_CUDNN == 1
+  return new CuDNNActivationOp(param);
+  #else
   switch(param.act_type) {
     case kReLU: return new ActivationOp<gpu, mshadow_op::relu, mshadow_op::relu_grad>();
     case kSigmoid: return new ActivationOp<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>();
@@ -19,6 +25,7 @@ Operator *CreateOp<gpu>(ActivationParam param) {
     LOG(FATAL) << "unknown activation";
     return NULL;
   }
+  #endif  // MXNET_USE_CUDNN
 }
 }  // op
 }  // namespace mxnet
diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h
index 0f3b303f85b6..1409615e853a 100644
--- a/src/operator/batch_norm-inl.h
+++ b/src/operator/batch_norm-inl.h
@@ -261,6 +261,10 @@ class BatchNormProp : public OperatorProperty {

   Operator* CreateOperator(Context ctx) const;

+  std::vector BackwardResource() const override {
+    return {Resource::kTempSpace};
+  }
+
  private:
   BatchNormParam param_;
 };  // class BatchNormProp
diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h
index b313aab14f94..f8a29b204d60 100644
--- a/src/operator/convolution-inl.h
+++ b/src/operator/convolution-inl.h
@@ -30,7 +30,7 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
   TShape pad;
   uint32_t num_filter;
   uint32_t num_group;
-  uint32_t nstep;
+  uint64_t workspace;
   bool no_bias;
   DMLC_DECLARE_PARAMETER(ConvolutionParam) {
     int shape[] = {1, 1};
@@ -44,8 +44,8 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
       .describe("convolution filter(channel) number");
     DMLC_DECLARE_FIELD(num_group).set_default(1)
       .describe("number of groups partition");
-    DMLC_DECLARE_FIELD(nstep).set_default(2).set_range(1, 10000)
-      .describe("process n images once");
+    DMLC_DECLARE_FIELD(workspace).set_default(128).set_range(1, 10000)
+      .describe("Temporary workspace for convolution (MB)");
     DMLC_DECLARE_FIELD(no_bias).set_default(false)
       .describe("Whether to disable bias parameter.");
   }
@@ -78,10 +78,14 @@ class ConvolutionOp : public Operator {
     TShape wmat_shape(ws, ws + 3);
     Tensor<xpu, 3> wmat = in_data[kWeight].get_with_shape<xpu, 3, real_t>(wmat_shape, s);
     Tensor<xpu, 4> out = out_data[kOut].get<xpu, 4, real_t>(s);
+    #if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
+        << "Must init CuBLAS handle in stream";
+    #endif
     this->InitTemp(ctx, data.shape_, out.shape_);
     const index_t nbatch = data.size(0);
-    for (index_t i = 0; i < nbatch; i += param_.nstep) {
-      const index_t step = std::min(param_.nstep, nbatch - i);
+    for (index_t i = 0; i < nbatch; i += nstep_) {
+      const index_t step = std::min(nstep_, nbatch - i);
       temp_col_.Resize(mshadow::Shape2(shape_colunit_[0],
                                        shape_colunit_[1] * step));
       temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0],
@@ -146,10 +150,14 @@ class ConvolutionOp : public Operator {
     Tensor<xpu, 4> grad = out_grad[kOut].get<xpu, 4, real_t>(s);
     Tensor<xpu, 4> gdata = in_grad[kData].get<xpu, 4, real_t>(s);
     Tensor<xpu, 3> gwmat = in_grad[kWeight].get_with_shape<xpu, 3, real_t>(wmat_shape, s);
+    #if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
+        << "Must init CuBLAS handle in stream";
+    #endif
     this->InitTemp(ctx, data.shape_, grad.shape_);
     const index_t nbatch = data.size(0);
-    for (index_t i = 0; i < nbatch; i += param_.nstep) {
-      const index_t step = std::min(param_.nstep, nbatch - i);
+    for (index_t i = 0; i < nbatch; i += nstep_) {
+      const index_t step = std::min(nstep_, nbatch - i);
       temp_col_.Resize(Shape2(shape_colunit_[0],
                               shape_colunit_[1] * step));
       temp_dst_.Resize(Shape3(shape_dstunit_[0],
@@ -220,16 +228,19 @@ class ConvolutionOp : public Operator {
     shape_dstunit_ = mshadow::Shape3(param_.num_group,
                                      param_.num_filter / param_.num_group,
                                      oshape[2] * oshape[3]);
-    int nop = (ishape[0] + param_.nstep - 1) / param_.nstep;
-    param_.nstep = (ishape[0] + nop - 1) / nop;
+    const uint64_t workspace_size = param_.workspace << 20;
+    nstep_ = std::max(std::min(static_cast<index_t>(workspace_size / shape_colunit_.Size()),
+                               ishape[0]), 1U);
+    int nop = (ishape[0] + nstep_ - 1) / nstep_;
+    nstep_ = (ishape[0] + nop - 1) / nop;
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
     temp_col_.set_stream(s);
     temp_dst_.set_stream(s);
     temp_col_.Resize(mshadow::Shape2(shape_colunit_[0],
-                                     shape_colunit_[1] * param_.nstep));
+                                     shape_colunit_[1] * nstep_));
     temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0],
                                      shape_dstunit_[1],
-                                     shape_dstunit_[2] * param_.nstep));
+                                     shape_dstunit_[2] * nstep_));
   }

   ConvolutionParam param_;
@@ -238,6 +249,7 @@ class ConvolutionOp : public Operator {
   mshadow::TensorContainer<xpu, 3> temp_dst_;
   mshadow::Shape<2> shape_colunit_;
   mshadow::Shape<3> shape_dstunit_;
+  index_t nstep_;
 };  // class ConvolutionOp

 template<typename xpu>
@@ -328,6 +340,14 @@ class ConvolutionProp : public OperatorProperty {

   Operator* CreateOperator(Context ctx) const;

+  std::vector ForwardResource() const override {
+    return {Resource::kTempSpace};
+  }
+
+  std::vector BackwardResource() const override {
+    return {Resource::kTempSpace};
+  }
+
  private:
   ConvolutionParam param_;
 };  // class ConvolutionProp
diff --git a/src/operator/convolution.cu b/src/operator/convolution.cu
index 4f0a3ce78b45..8127cec43fd6 100644
--- a/src/operator/convolution.cu
+++ b/src/operator/convolution.cu
@@ -6,12 +6,19 @@
 */

 #include "./convolution-inl.h"
+#if MXNET_USE_CUDNN == 1
+#include "./cudnn_convolution-inl.h"
+#endif  // MXNET_USE_CUDNN

 namespace mxnet {
 namespace op {
 template<>
 Operator* CreateOp<gpu>(ConvolutionParam param) {
+  #if MXNET_USE_CUDNN == 1
+  return new CuDNNConvolutionOp(param);
+  #else
   return new ConvolutionOp<gpu>(param);
+  #endif  // MXNET_USE_CUDNN
 }

 }  // namespace op
diff --git a/src/operator/cudnn_activation-inl.h b/src/operator/cudnn_activation-inl.h
new file mode 100644
index 000000000000..1158a1324128
--- /dev/null
+++ b/src/operator/cudnn_activation-inl.h
@@ -0,0 +1,132 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_activation-inl.h
+ * \brief
+ * \author Bing Xu
+*/
+
+#ifndef MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_
+#define MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_
+#include <algorithm>
+#include <vector>
+#include "./activation-inl.h"
+
+namespace mxnet {
+namespace op {
+class CuDNNActivationOp : public Operator {
+ public:
+  explicit CuDNNActivationOp(ActivationParam param) {
+    param_ = param;
+    init_cudnn_ = false;
+    dtype_ = CUDNN_DATA_FLOAT;
+    switch (param_.act_type) {
+      case kReLU:
+        mode_ = CUDNN_ACTIVATION_RELU;
+        break;
+      case kSigmoid:
+        mode_ = CUDNN_ACTIVATION_SIGMOID;
+        break;
+      case kTanh:
+        mode_ = CUDNN_ACTIVATION_TANH;
+        break;
+      default:
+        LOG(FATAL) << "Not implemented";
+        break;
+    }
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    if (!init_cudnn_) {
+      this->Init(s, in_data, out_data);
+    }
+    CHECK_EQ(cudnnActivationForward(s->dnn_handle_,
+                                    mode_,
+                                    &alpha,
+                                    shape_desc_,
+                                    data.dptr_,
+                                    &beta,
+                                    shape_desc_,
+                                    out.dptr_), CUDNN_STATUS_SUCCESS);
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(out_grad.size(), 1);
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    CHECK_EQ(req.size(), 1);
+    CHECK_EQ(in_grad.size(), 1);
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> grad = out_grad[kOut].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> output_data = out_data[kOut].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> input_grad = in_grad[kData].get<gpu, 4, real_t>(s);
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    CHECK_EQ(cudnnActivationBackward(s->dnn_handle_,
+                                     mode_,
+                                     &alpha,
+                                     shape_desc_,
+                                     output_data.dptr_,
+                                     shape_desc_,
+                                     grad.dptr_,
+                                     shape_desc_,
+                                     data.dptr_,
+                                     &beta,
+                                     shape_desc_,
+                                     input_grad.dptr_), CUDNN_STATUS_SUCCESS);
+  }
+
+ private:
+  inline void Init(mshadow::Stream<gpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+      Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+      CHECK_EQ(data.shape_, out.shape_);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          data.shape_[0],
+                                          data.shape_[1],
+                                          data.shape_[2],
+                                          data.shape_[3]), CUDNN_STATUS_SUCCESS);
+    }
+  }
+  bool init_cudnn_;
+  cudnnDataType_t dtype_;
+  cudnnActivationMode_t mode_;
+  cudnnTensorDescriptor_t shape_desc_;
+  ActivationParam param_;
+};  // class CuDNNActivationOp
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_
diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h
new file mode 100644
index 000000000000..8b81818304e1
--- /dev/null
+++ b/src/operator/cudnn_convolution-inl.h
@@ -0,0 +1,275 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_convolution-inl.h
+ * \brief
+ * \author Bing Xu
+*/
+#ifndef MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_
+#define MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_
+
+#include <algorithm>
+#include <vector>
+#include "./convolution-inl.h"
+
+namespace mxnet {
+namespace op {
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1
+class CuDNNConvolutionOp : public Operator {
+ public:
+  explicit CuDNNConvolutionOp(ConvolutionParam param) {
+    this->param_ = param;
+    init_cudnn_ = false;
+    // TODO(xxx): fp16
+    dtype_ = CUDNN_DATA_FLOAT;
+  }
+
+  ~CuDNNConvolutionOp() {
+    if (init_cudnn_) {
+      CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnDestroyTensorDescriptor(bias_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnDestroyFilterDescriptor(filter_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnDestroyConvolutionDescriptor(conv_desc_), CUDNN_STATUS_SUCCESS);
+    }
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    size_t expected = param_.no_bias ? 2 : 3;
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    CHECK_EQ(in_data.size(), expected);
+    CHECK_EQ(out_data.size(), 1);
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> wmat = in_data[kWeight].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+    CHECK_EQ(data.CheckContiguous(), true);
+    CHECK_EQ(wmat.CheckContiguous(), true);
+    CHECK_EQ(out.CheckContiguous(), true);
+    if (!init_cudnn_) {
+      Init(s, in_data, out_data);
+    }
+    CHECK_EQ(cudnnConvolutionForward(s->dnn_handle_,
+                                     &alpha,
+                                     in_desc_,
+                                     data.dptr_,
+                                     filter_desc_,
+                                     wmat.dptr_,
+                                     conv_desc_,
+                                     algo_,
+                                     temp_.dptr_,
+                                     param_.workspace,
+                                     &beta,
+                                     out_desc_,
+                                     out.dptr_), CUDNN_STATUS_SUCCESS);
+    if (!param_.no_bias) {
+      beta = 1.0f;
+      Tensor<gpu, 1> bias = in_data[kBias].get<gpu, 1, real_t>(s);
+      CHECK_EQ(cudnnAddTensor(s->dnn_handle_,
+                              CUDNN_ADD_SAME_C,
+                              &alpha,
+                              bias_desc_,
+                              bias.dptr_,
+                              &beta,
+                              out_desc_,
+                              out.dptr_), CUDNN_STATUS_SUCCESS);
+    }
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    size_t expected = param_.no_bias == 0 ? 3 : 2;
+    CHECK_EQ(out_grad.size(), 1);
+    CHECK(in_data.size() == expected && in_grad.size() == expected);
+    // TODO(bing): think about how to support add to
+    CHECK_EQ(req[kWeight], kWriteTo);
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> grad = out_grad[kOut].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> wmat = in_data[kWeight].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> gwmat = in_grad[kWeight].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> gdata = in_grad[kData].get<gpu, 4, real_t>(s);
+    if (!param_.no_bias) {
+      Tensor<gpu, 1> gbias = in_grad[kBias].get<gpu, 1, real_t>(s);
+      CHECK_EQ(cudnnConvolutionBackwardBias(s->dnn_handle_,
+                                            &alpha,
+                                            out_desc_,
+                                            grad.dptr_,
+                                            &beta,
+                                            bias_desc_,
+                                            gbias.dptr_), CUDNN_STATUS_SUCCESS);
+    }
+    CHECK_EQ(cudnnConvolutionBackwardFilter_v3(s->dnn_handle_,
+                                               &alpha,
+                                               in_desc_,
+                                               data.dptr_,
+                                               out_desc_,
+                                               grad.dptr_,
+                                               conv_desc_,
+                                               back_algo_w_,
+                                               temp_.dptr_,
+                                               param_.workspace,
+                                               &beta,
+                                               filter_desc_,
+                                               gwmat.dptr_), CUDNN_STATUS_SUCCESS);
+    CHECK_EQ(cudnnConvolutionBackwardData_v3(s->dnn_handle_,
+                                             &alpha,
+                                             filter_desc_,
+                                             wmat.dptr_,
+                                             out_desc_,
+                                             grad.dptr_,
+                                             conv_desc_,
+                                             back_algo_,
+                                             temp_.dptr_,
+                                             param_.workspace,
+                                             &beta,
+                                             in_desc_,
+                                             gdata.dptr_), CUDNN_STATUS_SUCCESS);
+  }
+
+ private:
+  inline void Init(mshadow::Stream<gpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    size_t expected = param_.no_bias ? 2 : 3;
+    CHECK_EQ(in_data.size(), expected);
+    CHECK_EQ(out_data.size(), 1);
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      temp_.set_stream(s);
+      size_t workspace = static_cast<size_t>(param_.workspace);
+      size_t back_size = 0;
+      size_t back_size_w = 0;
+      Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+      Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&bias_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateFilterDescriptor(&filter_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateConvolutionDescriptor(&conv_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_,
+                                          dtype_,
+                                          param_.num_filter,
+                                          data.shape_[1],
+                                          param_.kernel[0],
+                                          param_.kernel[1]), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetConvolution2dDescriptor(conv_desc_,
+                                               param_.pad[0],
+                                               param_.pad[1],
+                                               param_.stride[0],
+                                               param_.stride[1],
+                                               1,
+                                               1,
+                                               CUDNN_CROSS_CORRELATION), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          data.shape_[0],
+                                          data.shape_[1],
+                                          data.shape_[2],
+                                          data.shape_[3]), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          out.shape_[0],
+                                          out.shape_[1],
+                                          out.shape_[2],
+                                          out.shape_[3]), CUDNN_STATUS_SUCCESS);
+      if (!param_.no_bias) {
+        Tensor<gpu, 1> bias = in_data[kBias].get<gpu, 1, real_t>(s);
+        CHECK_EQ(cudnnSetTensor4dDescriptor(bias_desc_,
+                                            CUDNN_TENSOR_NCHW,
+                                            dtype_,
+                                            1,
+                                            bias.shape_[0],
+                                            1,
+                                            1), CUDNN_STATUS_SUCCESS);
+      }
+      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+      CHECK_EQ(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                   in_desc_,
+                                                   filter_desc_,
+                                                   conv_desc_,
+                                                   out_desc_,
+                                                   CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+                                                   param_.workspace,
+                                                   &algo_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                                          in_desc_,
+                                                          out_desc_,
+                                                          conv_desc_,
+                                                          filter_desc_,
+                                                          CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+                                                          param_.workspace,
+                                                          &back_algo_w_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                        filter_desc_,
+                                                        out_desc_,
+                                                        conv_desc_,
+                                                        in_desc_,
+                                                        CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+                                                        param_.workspace,
+                                                        &back_algo_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_,
+                                                            filter_desc_,
+                                                            out_desc_,
+                                                            conv_desc_,
+                                                            in_desc_,
+                                                            back_algo_,
+                                                            &back_size), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_,
+                                                              in_desc_,
+                                                              out_desc_,
+                                                              conv_desc_,
+                                                              filter_desc_,
+                                                              back_algo_w_,
+                                                              &back_size_w), CUDNN_STATUS_SUCCESS);
+      back_size = std::max(back_size, back_size_w);
+      CHECK_EQ(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_,
+                                                       in_desc_,
+                                                       filter_desc_,
+                                                       conv_desc_,
+                                                       out_desc_,
+                                                       algo_,
+                                                       &workspace), CUDNN_STATUS_SUCCESS);
+      workspace = std::max(workspace, back_size);
+      param_.workspace = workspace;
+      // TODO(bing): wait resource allocation
+      temp_.Resize(mshadow::Shape1(workspace / sizeof(real_t) + 1), 0.0f);
+    }
+  }
+
+  bool init_cudnn_;
+  cudnnDataType_t dtype_;
+  cudnnTensorDescriptor_t in_desc_;
+  cudnnTensorDescriptor_t out_desc_;
+  cudnnTensorDescriptor_t bias_desc_;
+  cudnnFilterDescriptor_t filter_desc_;
+  cudnnConvolutionDescriptor_t conv_desc_;
+  cudnnConvolutionFwdAlgo_t algo_;
+  cudnnConvolutionBwdDataAlgo_t back_algo_;
+  cudnnConvolutionBwdFilterAlgo_t back_algo_w_;
+  ConvolutionParam param_;
+  // TODO(bing): remove when we have resource manager
+  mshadow::TensorContainer<gpu, 1> temp_;
+};
+#endif  // __CUDACC__ && CUDNN
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_
diff --git a/src/operator/cudnn_pooling-inl.h b/src/operator/cudnn_pooling-inl.h
new file mode 100644
index 000000000000..83faeee70435
--- /dev/null
+++ b/src/operator/cudnn_pooling-inl.h
@@ -0,0 +1,154 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_pooling-inl.h
+ * \brief
+ * \author Bing Xu
+*/
+
+#ifndef MXNET_OPERATOR_CUDNN_POOLING_INL_H_
+#define MXNET_OPERATOR_CUDNN_POOLING_INL_H_
+#include <algorithm>
+#include <vector>
+#include "./pooling-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class CuDNNPoolingOp : public Operator {
+ public:
+  explicit CuDNNPoolingOp(PoolingParam p) {
+    param_ = p;
+    init_cudnn_ = false;
+    // TODO(xxx): fp16
+    dtype_ = CUDNN_DATA_FLOAT;
+    switch (param_.pool_type) {
+      case kMaxPooling:
+        mode_ = CUDNN_POOLING_MAX;
+        break;
+      case kAvgPooling:
+        mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+        break;
+      default:
+        LOG(FATAL) << "Not implemented";
+    }
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    if (!init_cudnn_) {
+      this->Init(s, in_data, out_data);
+    }
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    CHECK_EQ(data.CheckContiguous(), true);
+    CHECK_EQ(out.CheckContiguous(), true);
+    CHECK_EQ(cudnnPoolingForward(s->dnn_handle_,
+                                 pooling_desc_,
+                                 &alpha,
+                                 in_desc_,
+                                 data.dptr_,
+                                 &beta,
+                                 out_desc_,
+                                 out.dptr_), CUDNN_STATUS_SUCCESS);
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(out_grad.size(), 1);
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    CHECK_EQ(req.size(), 1);
+    CHECK_EQ(in_grad.size(), 1);
+
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+    Tensor<gpu, 4> m_out_grad = out_grad[kOut].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> m_in_data = in_data[kData].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> m_out_data = out_data[kOut].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 4> m_in_grad = in_grad[kData].get<gpu, 4, real_t>(s);
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    float alpha = 1.0f;
+    float beta = 0.0f;
+    CHECK_EQ(cudnnPoolingBackward(s->dnn_handle_,
+                                  pooling_desc_,
+                                  &alpha,
+                                  out_desc_,
+                                  m_out_data.dptr_,
+                                  out_desc_,
+                                  m_out_grad.dptr_,
+                                  in_desc_,
+                                  m_in_data.dptr_,
+                                  &beta,
+                                  in_desc_,
+                                  m_in_grad.dptr_), CUDNN_STATUS_SUCCESS);
+  }
+
+ private:
+  inline void Init(mshadow::Stream<gpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      Tensor<gpu, 4> data = in_data[kData].get<gpu, 4, real_t>(s);
+      Tensor<gpu, 4> out = out_data[kOut].get<gpu, 4, real_t>(s);
+      CHECK_EQ(cudnnCreatePoolingDescriptor(&pooling_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          data.shape_[0],
+                                          data.shape_[1],
+                                          data.shape_[2],
+                                          data.shape_[3]), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_,
+                                          CUDNN_TENSOR_NCHW,
+                                          dtype_,
+                                          out.shape_[0],
+                                          out.shape_[1],
+                                          out.shape_[2],
+                                          out.shape_[3]), CUDNN_STATUS_SUCCESS);
+      CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_,
+                                           mode_,
+                                           param_.kernel[0],
+                                           param_.kernel[1],
+                                           param_.pad[0],
+                                           param_.pad[1],
+                                           param_.stride[0],
+                                           param_.stride[1]), CUDNN_STATUS_SUCCESS);
+    }
+  }
+  bool init_cudnn_;
+  cudnnDataType_t dtype_;
+  cudnnHandle_t handle_;
+  cudnnPoolingMode_t mode_;
+  cudnnTensorDescriptor_t in_desc_;
+  cudnnTensorDescriptor_t out_desc_;
+  cudnnPoolingDescriptor_t pooling_desc_;
+  PoolingParam param_;
+};  // class CuDNNPoolingOp
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CUDNN_POOLING_INL_H_
+
diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h
index 75fe14d3aab8..35cb035d16f8 100644
--- a/src/operator/fully_connected-inl.h
+++ b/src/operator/fully_connected-inl.h
@@ -63,6 +63,10 @@ class FullyConnectedOp : public Operator {
     // maybe need blas handle from context
     // TODO(bing): judge shape to remove flatten op
     Stream<xpu> *s = ctx.get_stream<xpu>();
+    #if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
+        << "Must init CuBLAS handle in stream";
+    #endif  // __CUDACC__
     Tensor<xpu, 2> data = in_data[kData].FlatTo2D<xpu, real_t>(s);
     Tensor<xpu, 2> wmat = in_data[kWeight].get<xpu, 2, real_t>(s);
     Tensor<xpu, 2> out = out_data[kOut].FlatTo2D<xpu, real_t>(s);
@@ -92,6 +96,10 @@ class FullyConnectedOp : public Operator {
     Tensor<xpu, 2> data = in_data[kData].FlatTo2D<xpu, real_t>(s);
     Tensor<xpu, 2> wmat = in_data[kWeight].get<xpu, 2, real_t>(s);
     Tensor<xpu, 2> grad = out_grad[kOut].FlatTo2D<xpu, real_t>(s);
+    #if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
+        << "Must init CuBLAS handle in stream";
+    #endif
     // backprop
     CHECK_NE(req[kWeight], kWriteInplace) << "cannot write weight inplace";
     // gradient of weight
diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h
index b0d483ef0217..5748325d5835 100644
--- a/src/operator/pooling-inl.h
+++ b/src/operator/pooling-inl.h
@@ -201,7 +201,11 @@ class PoolingProp : public OperatorProperty {
       const std::vector<int> &in_data,
&out_data, const std::vector &in_grad) const override { + #if MXNET_USE_CUDNN == 1 + return {}; + #else return {{in_data[kData], in_grad[kData]}}; + #endif } Operator* CreateOperator(Context ctx) const; diff --git a/src/operator/pooling.cu b/src/operator/pooling.cu index 5037050ccd6f..df9547bf4a1e 100644 --- a/src/operator/pooling.cu +++ b/src/operator/pooling.cu @@ -6,11 +6,17 @@ */ #include "./pooling-inl.h" +#if MXNET_USE_CUDNN == 1 +#include "./cudnn_pooling-inl.h" +#endif // MXNET_USE_CUDNN namespace mxnet { namespace op { template<> Operator *CreateOp(PoolingParam param) { + #if MXNET_USE_CUDNN == 1 + return new CuDNNPoolingOp(param); + #else switch (param.pool_type) { case kMaxPooling: return new PoolingOp(param); case kAvgPooling: return new PoolingOp(param); @@ -19,6 +25,7 @@ Operator *CreateOp(PoolingParam param) { LOG(FATAL) << "unknown activation type"; return NULL; } + #endif // MXNET_USE_CUDNN } } // namespace op diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 914c8f8b8c9f..863355c28937 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -323,6 +323,13 @@ void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, } void GraphExecutor::InitDataEntryMemory() { + // setup the temp ref counter for allocator algorithms + for (OpNode &op : op_nodes_) { + for (DataEntryInfo &node : op.outputs) { + node.temp_ref_count = node.ref_count; + } + } + // use allocator to allocate memory. GraphStorageAllocator allocator(&graph_); for (size_t i = 0; i < topo_order_.size(); ++i) { @@ -337,7 +344,7 @@ void GraphExecutor::InitDataEntryMemory() { for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; CHECK_NE(info.type, kNotInitialized); - CHECK_NE(info.ref_count, 0); + CHECK_NE(info.temp_ref_count, 0); in_data.push_back(&info); } std::vector out_data(op_nodes_[nid].outputs.size()); @@ -350,7 +357,7 @@ void GraphExecutor::InitDataEntryMemory() { for (std::pair kv : inplace) { DataEntryInfo* in = kv.first; DataEntryInfo* out = kv.second; - if (in->ref_count == 1 && + if (in->temp_ref_count == 1 && in->type == kInternalAllocated && out->type == kNotInitialized) { // we can only do inplace if we are last user of in @@ -359,13 +366,13 @@ void GraphExecutor::InitDataEntryMemory() { out->op_req = kWriteInplace; out->storage_id = in->storage_id; // set inplace op id - in->ref_count = 0; + in->temp_ref_count = 0; in->inplace_op_id = static_cast(nid); } } // allocate output, for (DataEntryInfo *out : out_data) { - if (out->op_req == kNullOp && out->ref_count != 0) { + if (out->op_req == kNullOp && out->temp_ref_count != 0) { out->op_req = kWriteTo; } if (out->type == kNotInitialized) { @@ -376,20 +383,20 @@ void GraphExecutor::InitDataEntryMemory() { } // then free inputs for (DataEntryInfo *in : in_data) { - // ref_count == 0 means it is taken by inplace op - if (in->ref_count == 0) { + // temp_ref_count == 0 means it is taken by inplace op + if (in->temp_ref_count == 0) { CHECK_EQ(in->inplace_op_id, static_cast(nid)); continue; } // if we decrease it to zero, means we are ready to relase - --in->ref_count; - if (in->ref_count == 0 && in->type == kInternalAllocated) { + --in->temp_ref_count; + if (in->temp_ref_count == 0 && in->type == kInternalAllocated) { allocator.Release(in->storage_id, nid); } } - // check out again, if there is ref_count == 0, release it + // check out again, if there is temp_ref_count == 0, release it for (DataEntryInfo *out : out_data) { - if 
-      if (out->ref_count == 0 && out->type == kInternalAllocated) {
+      if (out->temp_ref_count == 0 && out->type == kInternalAllocated) {
         allocator.Release(out->storage_id, nid);
       }
     }
@@ -493,13 +500,26 @@ void GraphExecutor::Forward(bool is_train) {
 }

 void GraphExecutor::Backward(const std::vector<NArray> &head_grads) {
-  CHECK_EQ(head_grad_nodes_.size(), head_grads.size());
-  for (size_t i = 0; i < head_grad_nodes_.size(); ++i) {
-    uint32_t nid = head_grad_nodes_[i];
-    CHECK(graph_.nodes[nid].is_variable());
-    DataEntryInfo &info = op_nodes_[nid].outputs[0];
-    CHECK_EQ(info.type, kTobeBindByExternal);
-    info.data = head_grads[i];
+  if (head_grads.size() != 0) {
+    // TODO(bing, min): consider passing a map for backward
+    CHECK_EQ(head_grad_nodes_.size(), head_grads.size());
+    for (size_t i = 0; i < head_grad_nodes_.size(); ++i) {
+      uint32_t nid = head_grad_nodes_[i];
+      CHECK(graph_.nodes[nid].is_variable());
+      DataEntryInfo &info = op_nodes_[nid].outputs[0];
+      CHECK_EQ(info.type, kTobeBindByExternal);
+      info.data = head_grads[i];
+    }
+  } else {
+    // check that all head_grad_nodes have zero ref_count;
+    // a loss function does not need out_grad
+    for (size_t i = 0; i < head_grad_nodes_.size(); ++i) {
+      uint32_t nid = head_grad_nodes_[i];
+      DataEntryInfo &info = op_nodes_[nid].outputs[0];
+      CHECK_EQ(info.ref_count, 0)
+          << "Because the last operator is not a loss function, "
+          << "head_gradient is required when calling backward.";
+    }
   }
   RunOps(true, num_forward_nodes_, topo_order_.size());
 }
diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h
index af2160415e8d..074fafa0c571 100644
--- a/src/symbol/graph_executor.h
+++ b/src/symbol/graph_executor.h
@@ -79,6 +79,8 @@ class GraphExecutor : public Executor {
     // reference count on how many times this entry is being used.
     // That is how many operators and heads need this DataEntry
     // this is a temporal variable that is used during initialization.
+    uint32_t temp_ref_count;
+    // real permanent ref count
     uint32_t ref_count;
     // constructor
     DataEntryInfo()
@@ -86,7 +88,7 @@ class GraphExecutor : public Executor {
           inplace_op_id(-1),
           type(kNotInitialized),
           storage_id(GraphStorageAllocator::kBadStorageID),
-          ref_count(0) {}
+          temp_ref_count(0), ref_count(0) {}
   };
   // all the information needed to push the op to engine
   struct OpExecEntry {
diff --git a/tests/python/test_conv.py b/tests/python/test_conv.py
index 9ab34ce1c8ae..d63a0542ce7a 100644
--- a/tests/python/test_conv.py
+++ b/tests/python/test_conv.py
@@ -12,12 +12,12 @@ def CalAcc(out, label):
 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100)
+conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
 bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
 act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
 mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')

-conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100)
+conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2))
 bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2")
 act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu")
 mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')
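As a cross-check of the new `workspace` parameter, the step-size computation in `ConvolutionOp::InitTemp` can be mirrored in a few lines of plain Python. This is an illustrative sketch only, with `col_unit` standing in for `shape_colunit_.Size()` and hypothetical sizes:

```python
def pick_nstep(workspace_mb, col_unit, batch):
    # Mirrors ConvolutionOp::InitTemp above: how many images fit into the
    # workspace budget per pass, then even the passes out over the batch.
    budget = workspace_mb << 20                     # MB shifted by 20, as in the C++
    nstep = max(min(budget // col_unit, batch), 1)  # images per pass
    nop = (batch + nstep - 1) // nstep              # number of passes
    return (batch + nop - 1) // nop                 # evened-out step size

# With the 256 MB workspace set in cifar10.py and a hypothetical
# column-buffer unit size, the whole batch fits in one pass:
print(pick_nstep(256, col_unit=3 * 3 * 3 * 28 * 28, batch=128))  # -> 128
```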