diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 0259a97d4594..aecf63c86b45 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -85,12 +85,10 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
         if update_on_kvstore:
             kvstore.pull(idx, param_on_devs, priority=-idx)
 
-def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
-                              stype_dict=None, param_names=None):
+def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names=None):
     """Perform update of param_arrays from grad_arrays on kvstore.
        If `param_names` is None or kvstore doesn't have a `name2idx` dictionary,
        the index of a param is determined by the order it appears in `param_arrays`. """
-    stype_dict = {} if stype_dict is None else stype_dict
     for i, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
         if grad_list[0] is None:
@@ -99,31 +97,18 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
         if param_names is not None:
             name = param_names[i]
             index = index if name not in kvstore.name2idx else kvstore.name2idx[name]
-        # cast storage type if stype doesn't match
-        if name in stype_dict:
-            for j, grad in enumerate(grad_list):
-                stype = stype_dict[name]
-                if grad_list[j].storage_type != stype:
-                    grad_list[j] = nd.cast_storage(grad, stype)
         # push gradient, priority is negative index
         kvstore.push(index, grad_list, priority=-index)
         # pull back the weights
         kvstore.pull(index, arg_list, priority=-index)
 
 def _update_params(param_arrays, grad_arrays, updater, num_device,
-                   kvstore=None, stype_dict=None, param_names=None):
+                   kvstore=None, param_names=None):
     """Perform update of param_arrays from grad_arrays not on kvstore."""
-    stype_dict = {} if stype_dict is None else stype_dict
     for i, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
         if grad_list[0] is None:
             continue
-        # cast storage type if stype doesn't match
-        if param_names is not None and param_names[i] in stype_dict:
-            for j, grad in enumerate(grad_list):
-                stype = stype_dict[param_names[i]]
-                if grad_list[j].storage_type != stype:
-                    grad_list[j] = nd.cast_storage(grad, stype)
         index = i
         if kvstore:
             if param_names is not None:
@@ -136,8 +121,11 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
         for k, p in enumerate(zip(arg_list, grad_list)):
             # faked an index here, to make optimizer create diff
             # state for the same index but on diff devs, TODO(mli)
-            # use a better solution latter
+            # use a better solution later
             w, g = p
+            # cast storage type if stype doesn't match
+            if g.storage_type != w.storage_type:
+                g = nd.cast_storage(g, w.storage_type)
             updater(index*num_device+k, g, w)
 
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index c78daa1137c8..820841087a9c 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -849,17 +849,9 @@ def get_input_grads(self, merge_multi_context=True):
         """
         raise NotImplementedError()
 
-    def update(self, storage_type_dict=None):
+    def update(self):
         """Updates parameters according to the installed optimizer and the gradients computed
-        in the previous forward-backward batch. The storage type of parameters is casted according
-        to `storage_type_dict`, if provided.
-
-        Parameters
-        ----------
-        storage_type_dict: dict of str to str
-            Defaults to ``None``. Desired storage types of parameters for parameter update. If the
-            parameter gradient is not of desired storage type, its storage type will be casted
-            before the update.
+        in the previous forward-backward batch.
 
         Examples
         --------
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index ae10e8e401d0..11922ddafb56 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -399,13 +399,13 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._curr_module.backward(out_grads=out_grads)
 
-    def update(self, storage_type_dict=None):
+    def update(self):
         """Updates parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
         self._params_dirty = True
-        self._curr_module.update(storage_type_dict=storage_type_dict)
+        self._curr_module.update()
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs from a previous forward computation.
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a0eb19dafccc..26221078cee1 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -562,7 +562,7 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._exec_group.backward(out_grads=out_grads)
 
-    def update(self, storage_type_dict=None):
+    def update(self):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch.
 
@@ -577,7 +577,6 @@ def update(self, storage_type_dict=None):
             _update_params_on_kvstore(self._exec_group.param_arrays,
                                       self._exec_group.grad_arrays,
                                       self._kvstore,
-                                      stype_dict=storage_type_dict,
                                       param_names=self._param_names)
         else:
             _update_params(self._exec_group.param_arrays,
diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py
index 82dcb06aa020..f46ea280aaff 100644
--- a/python/mxnet/module/python_module.py
+++ b/python/mxnet/module/python_module.py
@@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
         """
         pass
 
-    def update(self, storage_type_dict=None):
+    def update(self):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch. Currently we do nothing here. Subclass should
         override this method if contains parameters.
diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py
index 383286642e0c..21e30fb3b0ce 100644
--- a/python/mxnet/module/sequential_module.py
+++ b/python/mxnet/module/sequential_module.py
@@ -344,14 +344,14 @@ def backward(self, out_grads=None):
                 out_grads = module.get_input_grads()
 
-    def update(self, storage_type_dict=None):
+    def update(self):
         """Updates parameters according to installed optimizer and the gradient computed
         in the previous forward-backward cycle.
         """
         assert self.binded and self.params_initialized and self.optimizer_initialized
 
         for module in self._modules:
-            module.update(storage_type_dict=storage_type_dict)
+            module.update()
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs from a previous forward computation.
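Note: with `storage_type_dict` gone from the `update()` chain, the cast to the weight's storage type now happens once, inside `_update_params`, right before the updater call; the kvstore path pushes gradients in whatever storage type they already have. A minimal standalone sketch of the new code path (hypothetical `w` and `g`, not part of the patch):

    import mxnet as mx
    from mxnet import nd

    w = mx.sparse_nd.zeros('row_sparse', (4, 4))  # weight kept in row_sparse storage
    g = mx.nd.ones((4, 4))                        # dense gradient from backward
    # the same check _update_params now performs per weight/gradient pair:
    if g.storage_type != w.storage_type:
        g = nd.cast_storage(g, w.storage_type)    # dense -> row_sparse before the update
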
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 8e8d3ffebbd4..6167369110d7 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -1022,7 +1022,7 @@ def empty(shape, ctx=None, dtype=mx_real_t):
         ctx = Context.default_ctx
     return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
 
-def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs):
+def zeros(shape, ctx=None, dtype=None, **kwargs):
     """Returns a new array filled with all zeros, with the given shape and type.
 
     Parameters
@@ -1053,11 +1053,12 @@ def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs):
     # pylint: disable= unused-argument
     if ctx is None:
         ctx = Context.default_ctx
+    dtype = mx_real_t if dtype is None else dtype
     # pylint: disable= no-member, protected-access
     return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
     # pylint: enable= no-member, protected-access
 
-def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
+def ones(shape, ctx=None, dtype=None, **kwargs):
     """Returns a new array filled with all ones, with the given shape and type.
 
     Parameters
@@ -1089,6 +1090,7 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
     # pylint: disable= unused-argument
     if ctx is None:
         ctx = Context.default_ctx
+    dtype = mx_real_t if dtype is None else dtype
     # pylint: disable= no-member, protected-access
     return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
     # pylint: enable= no-member, protected-access
diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py
index bc06fc1d1113..fe3239ae0bfa 100644
--- a/python/mxnet/sparse_ndarray.py
+++ b/python/mxnet/sparse_ndarray.py
@@ -600,7 +600,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
     array([[ 0.,  0.]], dtype=float16)
     """
     if storage_type == 'default':
-        return ndarray.zeros(shape, ctx, dtype, **kwargs)
+        return ndarray.zeros(shape, ctx=ctx, dtype=dtype, **kwargs)
     if ctx is None:
         ctx = Context.default_ctx
     dtype = mx_real_t if dtype is None else dtype
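Note: switching the default to `dtype=None` lets `sparse_ndarray.zeros` delegate to `ndarray.zeros` by keyword without clobbering the caller's dtype; both functions resolve `None` to `mx_real_t` themselves. A small usage sketch (assuming numpy is imported as np):

    import mxnet as mx
    import numpy as np

    a = mx.nd.zeros((2, 2))                    # dtype=None resolves to mx_real_t (float32)
    b = mx.nd.zeros((2, 2), dtype=np.float16)  # an explicit dtype is preserved
    c = mx.sparse_nd.zeros('default', (2, 2))  # forwards dtype=None unchanged
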
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 1197d4ef3edb..e1ab5c9557e0 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -29,9 +29,10 @@ class Comm {
   }
   virtual ~Comm() { }
   /**
-   * \brief init key with the data shape
+   * \brief init key with the data shape and storage shape
    */
-  virtual void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) = 0;
+  virtual void Init(int key, const NDArrayStorageType stype,
+                    const TShape& shape, int dtype = mshadow::kFloat32) = 0;
   /**
    * \brief returns src[0] + .. + src[src.size()-1]
    */
@@ -67,8 +68,13 @@ class CommCPU : public Comm {
   }
   virtual ~CommCPU() { }
 
-  void Init(int key, const TShape& shape, int type = mshadow::kFloat32) override {
-    merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
+  void Init(int key, const NDArrayStorageType stype, const TShape& shape,
+            int type = mshadow::kFloat32) override {
+    if (stype == kDefaultStorage) {
+      merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
+    } else {
+      merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type);
+    }
   }
 
   const NDArray& Reduce(int key, const std::vector<NDArray>& src,
@@ -78,29 +84,56 @@ class CommCPU : public Comm {
     if (src.size() == 1) {
       return src[0];
     }
-    std::vector<Engine::VarHandle> const_vars(src.size() - 1);
-    std::vector<NDArray> reduce(src.size());
     auto& buf = merge_buf_[key];
-    CopyFromTo(src[0], &buf.merged, priority);
-    reduce[0] = buf.merged;
+    if (buf.merged.storage_type() == kDefaultStorage) {
+      std::vector<Engine::VarHandle> const_vars(src.size() - 1);
+      std::vector<NDArray> reduce(src.size());
+      CopyFromTo(src[0], &buf.merged, priority);
+      reduce[0] = buf.merged;
 
-    if (buf.copy_buf.empty()) {
-      buf.copy_buf.resize(src.size()-1);
-      for (size_t j = 0; j < src.size() - 1; ++j) {
-        buf.copy_buf[j] = NDArray(
-          src[0].shape(), pinned_ctx_, false, src[0].dtype());
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size()-1);
+        for (size_t j = 0; j < src.size() - 1; ++j) {
+          // allocate NDArray based on storage type
+          buf.copy_buf[j] = NDArray(
+            src[0].shape(), pinned_ctx_, false, src[0].dtype());
+        }
+      }
+      for (size_t i = 1; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
+        reduce[i] = buf.copy_buf[i-1];
+        const_vars[i-1] = reduce[i].var();
       }
-    }
-    for (size_t i = 1; i < src.size(); ++i) {
-      CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
-      reduce[i] = buf.copy_buf[i-1];
-      const_vars[i-1] = reduce[i].var();
-    }
 
-    Engine::Get()->PushSync([reduce, this](RunContext rctx) {
-        ReduceSumCPU(reduce);
-      }, Context::CPU(), const_vars, {reduce[0].var()},
-      FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
+      Engine::Get()->PushSync([reduce, this](RunContext rctx) {
+          ReduceSumCPU(reduce);
+        }, Context::CPU(), const_vars, {reduce[0].var()},
+        FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
+
+    } else {
+      // buf.merged is a sparse ndarray.
+      std::vector<Engine::VarHandle> const_vars(src.size());
+      std::vector<NDArray> reduce(src.size());
+
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j) {
+          buf.copy_buf[j] = NDArray(
+            src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+        reduce[i] = buf.copy_buf[i];
+        const_vars[i] = reduce[i].var();
+      }
+      auto result = buf.merged;
+      Engine::Get()->PushSync([reduce, result, this](RunContext rctx) {
+          NDArray out = result;
+          ReduceSumCPUEx(reduce, &out);
+        }, Context::CPU(), const_vars, {result.var()},
+        FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
+    }
 
     return buf.merged;
   }
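Note: the `ReduceSumCPUEx` kernel added in the hunk that follows merges row-sparse arrays in three steps: gather every input's row indices, sort and deduplicate them to obtain the unique non-zero rows, then walk each input with a per-input offset and accumulate matching rows into the output slot. A rough numpy sketch of the same algorithm (illustrative only; it uses a dict where the C++ code uses the offset walk):

    import numpy as np

    def reduce_row_sparse(inputs, num_cols):
        """inputs: list of (row_indices, values) pairs with sorted indices."""
        # gather and dedup indices from all inputs -> unique non-zero rows
        uniq = np.unique(np.concatenate([idx for idx, _ in inputs]))
        out = np.zeros((len(uniq), num_cols))
        slot = {r: i for i, r in enumerate(uniq)}  # row id -> output slot
        for idx, vals in inputs:
            for r, row in zip(idx, vals):
                out[slot[r]] += row                # accumulate row-wise
        return uniq, out
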
@@ -133,6 +166,79 @@ class CommCPU : public Comm {
       });
   }
 
+  // serial implementation of reduce sum for row sparse NDArray.
+  // TODO(haibin) use openmp kernel to parallelize the summation
+  inline void ReduceSumCPUEx(const std::vector<NDArray> &in, NDArray *out) {
+    using namespace rowsparse;
+    using namespace mshadow;
+    auto stype = out->storage_type();
+    CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype;
+    size_t total_num_rows = 0;
+    size_t num_in = in.size();
+    // skip the ones with empty indices and values
+    std::vector<bool> skip(num_in, false);
+    // the values tensor of the inputs
+    MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
+      MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
+        std::vector<Tensor<cpu, 2, DType>> in_vals(num_in);
+        std::vector<Tensor<cpu, 1, IType>> in_indices(num_in);
+        // offset to the values tensor of all inputs
+        std::vector<size_t> offsets(num_in, 0);
+        std::vector<size_t> num_rows(num_in, 0);
+        for (size_t i = 0; i < num_in; i++) {
+          if (!in[i].storage_initialized()) {
+            skip[i] = true;
+            continue;
+          }
+          auto size = in[i].aux_shape(kIdx).Size();
+          num_rows[i] = size;
+          total_num_rows += size;
+          in_vals[i] = in[i].data().FlatTo2D<cpu, DType>();
+          in_indices[i] = in[i].aux_data(kIdx).FlatTo1D<cpu, IType>();
+        }
+        std::vector<IType> indices;
+        indices.reserve(total_num_rows);
+        // gather indices from all inputs
+        for (size_t i = 0; i < num_in; i++) {
+          for (size_t j = 0; j < num_rows[i]; j++) {
+            indices.emplace_back(in_indices[i][j]);
+          }
+        }
+        CHECK_EQ(indices.size(), total_num_rows);
+        // dedup indices
+        std::sort(indices.begin(), indices.end());
+        indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin());
+        // the ones left are the unique non-zero rows
+        size_t nnr = indices.size();
+        // allocate memory for output
+        out->CheckAndAlloc({Shape1(nnr)});
+        auto idx_data = out->aux_data(kIdx).FlatTo1D<cpu, IType>();
+        auto val_data = out->data().FlatTo2D<cpu, DType>();
+
+        for (size_t i = 0; i < nnr; i++) {
+          // copy indices back
+          idx_data[i] = indices[i];
+          bool zeros = true;
+          for (size_t j = 0; j < num_in; j++) {
+            if (skip[j]) continue;
+            size_t offset = offsets[j];
+            if (offset < num_rows[j]) {
+              if (indices[i] == in_indices[j][offset]) {
+                if (zeros) {
+                  Copy(val_data[i], in_vals[j][offset], nullptr);
+                  zeros = false;
+                } else {
+                  val_data[i] += in_vals[j][offset];
+                }
+                offsets[j] += 1;
+              }
+            }
+          }
+        }
+      });
+    });
+  }
+
   template<typename DType>
   inline static void ReduceSumCPU(
       const std::vector<DType*> &dptr, size_t offset, index_t size) {
@@ -216,8 +322,13 @@ class CommDevice : public Comm {
 
   virtual ~CommDevice() { }
 
-  void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) override {
-    sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
+  void Init(int key, const NDArrayStorageType stype, const TShape& shape,
+            int dtype = mshadow::kFloat32) override {
+    if (stype == kDefaultStorage) {
+      sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
+    } else {
+      LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
+    }
   }
 
   const NDArray& Reduce(int key, const std::vector<NDArray>& src,
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 5f5a0cc67a64..62ec06c30fab 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -63,7 +63,7 @@ class KVStoreDist : public KVStoreLocal {
                   const std::vector<NDArray>& values) override {
     CheckUnique(keys);
     for (size_t i = 0; i < keys.size(); ++i) {
-      comm_->Init(keys[i], values[i].shape(), values[i].dtype());
+      comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
     }
     if (get_rank() == 0) {
       Push_(keys, values, 0, false);
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index caa57a20d46e..5506f2c76bb3 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -43,7 +43,7 @@ class KVStoreLocal : public KVStore {
       CHECK(local_.find(keys[i]) == local_.end())
           << "duplicate init of key " << keys[i];
       local_[keys[i]] = values[i].Copy(pinned_ctx_);
-      comm_->Init(keys[i], values[i].shape(), values[i].dtype());
+      comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
     }
   }
 
@@ -67,7 +67,11 @@ class KVStoreLocal : public KVStore {
         }
         updater_(key, merged, &local);
       } else {
-        local = merged;
+        if (merged.storage_type() != local.storage_type()) {
+          local = merged.Copy(local.ctx());
+        } else {
+          local = merged;
+        }
       }
     }
   }
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index dd8149d4822e..a64bfcae0868 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -1,19 +1,19 @@
 # pylint: skip-file
 import mxnet as mx
 import numpy as np
+from mxnet.test_utils import rand_ndarray, assert_almost_equal
 
 shape = (4, 4)
 keys = [5, 7, 11]
 
-def init_kv():
+def init_kv(stype='default'):
     """init kv """
     kv = mx.kv.create()
     # single
-    kv.init(3, mx.nd.zeros(shape))
+    kv.init(3, mx.sparse_nd.zeros(stype, shape))
     # list
-    kv.init(keys, [mx.nd.zeros(shape)] * len(keys))
+    kv.init(keys, [mx.sparse_nd.zeros(stype, shape)] * len(keys))
     return kv
 
-
 def check_diff_to_scalar(A, x):
     """ assert A == x"""
     assert(np.sum(np.abs((A - x).asnumpy())) == 0)
@@ -74,6 +74,42 @@ def test_aggregator():
         for v in vv:
             check_diff_to_scalar(v, num_devs * 2.0)
 
+def test_sparse_aggregator():
+    """aggregate sparse ndarray on multiple devices"""
+
+    stype = 'row_sparse'
+    kv = init_kv(stype)
+
+    # devices
+    num_devs = 4
+    devs = [mx.Context('cpu', i) for i in range(num_devs)]
+
+    # single
+    vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]
+    expected_sum = np.zeros(shape)
+    for v in vals:
+        expected_sum += v.asnumpy()
+
+    kv.push(3, vals)
+    kv.pull(3, out = vals)
+    result_sum = np.zeros(shape)
+    for v in vals:
+        result_sum += v.asnumpy()
+    assert_almost_equal(result_sum, expected_sum * num_devs)
+
+    # list
+    vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys)
+    expected_sum = np.zeros(shape)
+    for v in vals[0]:
+        expected_sum += v.asnumpy()
+
+    kv.push(keys, vals)
+    kv.pull(keys, out = vals)
+    for vv in vals:
+        result_sum = np.zeros(shape)
+        for v in vv:
+            result_sum += v.asnumpy()
+        assert_almost_equal(result_sum, expected_sum * num_devs)
 
 def updater(key, recv, local):
     """use updater: +="""
@@ -121,5 +157,6 @@ def test_get_type():
     test_get_type()
     test_single_kv_pair()
     test_list_kv_pair()
+    test_sparse_aggregator()
     test_aggregator()
     test_updater()
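Note: in `test_sparse_aggregator` each device pushes its own random array, the store reduces them into a single aggregate on top of the zero-initialized value, and every pulled copy then holds that aggregate; summing the pulled copies therefore yields `num_devs` times the expected sum, which is exactly what the assertions check. In plain numpy terms (hypothetical shapes, not part of the patch):

    import numpy as np

    num_devs = 4
    vals = [np.random.rand(4, 4) for _ in range(num_devs)]
    aggregate = sum(vals)            # what push + reduce computes
    pulled = [aggregate] * num_devs  # pull writes the aggregate back to every device
    assert np.allclose(sum(pulled), aggregate * num_devs)
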
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 470312352b0e..d2a1f7fa3a3e 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -379,12 +379,12 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N
 def test_module_fm():
     mx.random.seed(11)
     rnd.seed(11)
-    def fm_model(k, feature_dim, storage_type='default'):
-        initializer = mx.initializer.Normal(sigma=0.01)
-        x = mx.symbol.Variable("data", storage_type=storage_type)
-        v = mx.symbol.Variable("v", shape=(feature_dim, k), init=initializer)
+    def fm_model(k, feature_dim):
+        norm = mx.initializer.Normal(sigma=0.01)
+        x = mx.symbol.Variable("data", storage_type='csr')
+        v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse')
 
-        w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=initializer)
+        w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm)
         w1 = mx.symbol.dot(x, w1_weight)
 
         v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
@@ -400,25 +400,26 @@ def fm_model(k, feature_dim, storage_type='default'):
         model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
         return model
 
+    # model
     ctx = default_context()
     k = 5
     feature_dim = 20
-    model = fm_model(k, feature_dim, 'csr')
+    model = fm_model(k, feature_dim)
 
+    # data iter
     num_batches = 8
     batch_size = 25
+    num_samples = batch_size * num_batches
     import scipy.sparse as sp
-    scipy_data = sp.rand(num_batches * batch_size, feature_dim,
-                         density=0.5, format='csr')
-    dns_label = mx.nd.ones((num_batches * batch_size,1))
-    csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices,
-                                (num_batches * batch_size, feature_dim))
-    data = csr_data
-
-    train_iter = mx.io.NDArrayIter(data=data,
-                                   label={'out_label':dns_label},
+    # generate some random scipy csr data
+    csr_sp = sp.rand(num_samples, feature_dim, density=0.5, format='csr')
+    csr_nd = mx.sparse_nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices,
+                              (num_samples, feature_dim))
+    label = mx.nd.ones((num_samples,1))
+    # the alternative is to use LibSVMIter
+    train_iter = mx.io.NDArrayIter(data=csr_nd,
+                                   label={'out_label':label},
                                    batch_size=batch_size)
-
     # create module
     mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label'])
     # allocate memory by given the input data and lable shapes
@@ -429,9 +430,7 @@ def fm_model(k, feature_dim, storage_type='default'):
     mod.init_optimizer(optimizer='sgd')
     # use accuracy as the metric
     metric = mx.metric.create('MSE')
-    # train 5 epoch, i.e. going over the data iter one pass
-    storage_type_dict = {'v' : 'row_sparse'}
-
+    # train 10 epochs
    for epoch in range(10):
         train_iter.reset()
         metric.reset()
@@ -439,11 +438,10 @@ def fm_model(k, feature_dim, storage_type='default'):
             mod.forward(batch, is_train=True)       # compute predictions
             mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
             mod.backward()                          # compute gradients
-            mod.update(storage_type_dict)           # update parameters
+            mod.update()                            # update parameters
     # print('Epoch %d, Training %s' % (epoch, metric.get()))
     assert(metric.get()[1] < 0.2)
 
-
 if __name__ == '__main__':
     test_module_dtype()
     test_module_input_grads()
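Note: the `fm_model` graph in this test builds the standard factorization machine score, a linear term plus a pairwise term derived from the factor matrix `v`: y = x.w1 + 0.5 * sum_f [ (x.v)_f^2 - (x^2 . v^2)_f ]. A dense numpy sketch of the same score (illustrative only; the test assembles it from sparse symbols):

    import numpy as np

    def fm_score(x, w1, v):
        """x: (batch, d) inputs, w1: (d, 1) weights, v: (d, k) factors."""
        linear = x.dot(w1).ravel()                              # first-order term
        pairwise = 0.5 * ((x.dot(v)) ** 2 - (x ** 2).dot(v ** 2)).sum(axis=1)
        return linear + pairwise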