From 16a6d7f98a35e881f3dd098a5794dc384e87607a Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sat, 10 Jun 2017 11:26:48 -0700 Subject: [PATCH] serial elemwise sum impl (#80) update module kvstore interface add other missing params and functions revert some interface changes revert some more changes reomve explicit casting for gradients on kvstore update Comm interface update fm example Conflicts: python/mxnet/model.py python/mxnet/ndarray.py --- python/mxnet/model.py | 24 +--- python/mxnet/module/base_module.py | 12 +- python/mxnet/module/bucketing_module.py | 4 +- python/mxnet/module/module.py | 3 +- python/mxnet/module/python_module.py | 2 +- python/mxnet/module/sequential_module.py | 4 +- python/mxnet/ndarray.py | 6 +- python/mxnet/sparse_ndarray.py | 2 +- src/kvstore/comm.h | 161 +++++++++++++++++++---- src/kvstore/kvstore_dist.h | 2 +- src/kvstore/kvstore_local.h | 8 +- tests/python/unittest/test_kvstore.py | 45 ++++++- tests/python/unittest/test_module.py | 40 +++--- 13 files changed, 222 insertions(+), 91 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 0259a97d4594..aecf63c86b45 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -85,12 +85,10 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, if update_on_kvstore: kvstore.pull(idx, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, - stype_dict=None, param_names=None): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names=None): """Perform update of param_arrays from grad_arrays on kvstore. If `param_names` is None or kvstore doesn't have a `name2idx` dictionary, the index of a param is determined by the order it appears in `param_arrays`. """ - stype_dict = {} if stype_dict is None else stype_dict for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: @@ -99,31 +97,18 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, if param_names is not None: name = param_names[i] index = index if name not in kvstore.name2idx else kvstore.name2idx[name] - # cast storage type if stype doesn't match - if name in stype_dict: - for j, grad in enumerate(grad_list): - stype = stype_dict[name] - if grad_list[j].storage_type != stype: - grad_list[j] = nd.cast_storage(grad, stype) # push gradient, priority is negative index kvstore.push(index, grad_list, priority=-index) # pull back the weights kvstore.pull(index, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None, stype_dict=None, param_names=None): + kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" - stype_dict = {} if stype_dict is None else stype_dict for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue - # cast storage type if stype doesn't match - if param_names is not None and param_names[i] in stype_dict: - for j, grad in enumerate(grad_list): - stype = stype_dict[param_names[i]] - if grad_list[j].storage_type != stype: - grad_list[j] = nd.cast_storage(grad, stype) index = i if kvstore: if param_names is not None: @@ -136,8 +121,11 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) - # use a better solution latter + # use a better solution later w, g = p + # cast storage type if stype doesn't match + if g.storage_type != w.storage_type: + g = nd.cast_storage(g, w.storage_type) updater(index*num_device+k, g, w) diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index c78daa1137c8..820841087a9c 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -849,17 +849,9 @@ def get_input_grads(self, merge_multi_context=True): """ raise NotImplementedError() - def update(self, storage_type_dict=None): + def update(self): """Updates parameters according to the installed optimizer and the gradients computed - in the previous forward-backward batch. The storage type of parameters is casted according - to `storage_type_dict`, if provided. - - Parameters - ---------- - storage_type_dict: dict of str to str - Defaults to ``None``. Desired storage types of parameters for parameter update. If the - parameter gradient is not of desired storage type, its storage type will be casted - before the update. + in the previous forward-backward batch. Examples -------- diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index ae10e8e401d0..11922ddafb56 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -399,13 +399,13 @@ def backward(self, out_grads=None): assert self.binded and self.params_initialized self._curr_module.backward(out_grads=out_grads) - def update(self, storage_type_dict=None): + def update(self): """Updates parameters according to installed optimizer and the gradient computed in the previous forward-backward cycle. """ assert self.binded and self.params_initialized and self.optimizer_initialized self._params_dirty = True - self._curr_module.update(storage_type_dict=storage_type_dict) + self._curr_module.update() def get_outputs(self, merge_multi_context=True): """Gets outputs from a previous forward computation. diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index a0eb19dafccc..26221078cee1 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -562,7 +562,7 @@ def backward(self, out_grads=None): assert self.binded and self.params_initialized self._exec_group.backward(out_grads=out_grads) - def update(self, storage_type_dict=None): + def update(self): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. @@ -577,7 +577,6 @@ def update(self, storage_type_dict=None): _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, self._kvstore, - stype_dict=storage_type_dict, param_names=self._param_names) else: _update_params(self._exec_group.param_arrays, diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py index 82dcb06aa020..f46ea280aaff 100644 --- a/python/mxnet/module/python_module.py +++ b/python/mxnet/module/python_module.py @@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non """ pass - def update(self, storage_type_dict=None): + def update(self): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. Currently we do nothing here. Subclass should override this method if contains parameters. diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py index 383286642e0c..21e30fb3b0ce 100644 --- a/python/mxnet/module/sequential_module.py +++ b/python/mxnet/module/sequential_module.py @@ -344,14 +344,14 @@ def backward(self, out_grads=None): out_grads = module.get_input_grads() - def update(self, storage_type_dict=None): + def update(self): """Updates parameters according to installed optimizer and the gradient computed in the previous forward-backward cycle. """ assert self.binded and self.params_initialized and self.optimizer_initialized for module in self._modules: - module.update(storage_type_dict=storage_type_dict) + module.update() def get_outputs(self, merge_multi_context=True): """Gets outputs from a previous forward computation. diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 8e8d3ffebbd4..6167369110d7 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -1022,7 +1022,7 @@ def empty(shape, ctx=None, dtype=mx_real_t): ctx = Context.default_ctx return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) -def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs): +def zeros(shape, ctx=None, dtype=None, **kwargs): """Returns a new array filled with all zeros, with the given shape and type. Parameters @@ -1053,11 +1053,12 @@ def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs): # pylint: disable= unused-argument if ctx is None: ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) # pylint: enable= no-member, protected-access -def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): +def ones(shape, ctx=None, dtype=None, **kwargs): """Returns a new array filled with all ones, with the given shape and type. Parameters @@ -1089,6 +1090,7 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): # pylint: disable= unused-argument if ctx is None: ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) # pylint: enable= no-member, protected-access diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index bc06fc1d1113..fe3239ae0bfa 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -600,7 +600,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs): array([[ 0., 0.]], dtype=float16) """ if storage_type == 'default': - return ndarray.zeros(shape, ctx, dtype, **kwargs) + return ndarray.zeros(shape, ctx=ctx, dtype=dtype, **kwargs) if ctx is None: ctx = Context.default_ctx dtype = mx_real_t if dtype is None else dtype diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 1197d4ef3edb..e1ab5c9557e0 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -29,9 +29,10 @@ class Comm { } virtual ~Comm() { } /** - * \brief init key with the data shape + * \brief init key with the data shape and storage shape */ - virtual void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) = 0; + virtual void Init(int key, const NDArrayStorageType stype, + const TShape& shape, int dtype = mshadow::kFloat32) = 0; /** * \brief returns src[0] + .. + src[src.size()-1] */ @@ -67,8 +68,13 @@ class CommCPU : public Comm { } virtual ~CommCPU() { } - void Init(int key, const TShape& shape, int type = mshadow::kFloat32) override { - merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int type = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + } else { + merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type); + } } const NDArray& Reduce(int key, const std::vector& src, @@ -78,29 +84,56 @@ class CommCPU : public Comm { if (src.size() == 1) { return src[0]; } - std::vector const_vars(src.size() - 1); - std::vector reduce(src.size()); auto& buf = merge_buf_[key]; - CopyFromTo(src[0], &buf.merged, priority); - reduce[0] = buf.merged; + if (buf.merged.storage_type() == kDefaultStorage) { + std::vector const_vars(src.size() - 1); + std::vector reduce(src.size()); + CopyFromTo(src[0], &buf.merged, priority); + reduce[0] = buf.merged; - if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()-1); - for (size_t j = 0; j < src.size() - 1; ++j) { - buf.copy_buf[j] = NDArray( - src[0].shape(), pinned_ctx_, false, src[0].dtype()); + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()-1); + for (size_t j = 0; j < src.size() - 1; ++j) { + // allocate NDArray basd on storage type + buf.copy_buf[j] = NDArray( + src[0].shape(), pinned_ctx_, false, src[0].dtype()); + } + } + for (size_t i = 1; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); + reduce[i] = buf.copy_buf[i-1]; + const_vars[i-1] = reduce[i].var(); } - } - for (size_t i = 1; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); - reduce[i] = buf.copy_buf[i-1]; - const_vars[i-1] = reduce[i].var(); - } - Engine::Get()->PushSync([reduce, this](RunContext rctx) { - ReduceSumCPU(reduce); - }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + Engine::Get()->PushSync([reduce, this](RunContext rctx) { + ReduceSumCPU(reduce); + }, Context::CPU(), const_vars, {reduce[0].var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + + } else { + // buf.merged is a sparse ndarray. + std::vector const_vars(src.size()); + std::vector reduce(src.size()); + + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + buf.copy_buf[j] = NDArray( + src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); + } + } + for (size_t i = 0; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + const_vars[i] = reduce[i].var(); + } + auto result = buf.merged; + Engine::Get()->PushSync([reduce, result, this](RunContext rctx) { + NDArray out = result; + ReduceSumCPUEx(reduce, &out); + }, Context::CPU(), const_vars, {result.var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + } return buf.merged; } @@ -133,6 +166,79 @@ class CommCPU : public Comm { }); } + // serial implementation of reduce sum for row sparse NDArray. + // TODO(haibin) use openmp kernel to parallelize the summation + inline void ReduceSumCPUEx(const std::vector &in, NDArray *out) { + using namespace rowsparse; + using namespace mshadow; + auto stype = out->storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype; + size_t total_num_rows = 0; + size_t num_in = in.size(); + // skip the ones with empty indices and values + std::vector skip(num_in, false); + // the values tensor of the inputs + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector> in_vals(num_in); + std::vector> in_indices(num_in); + // offset to the values tensor of all inputs + std::vector offsets(num_in, 0); + std::vector num_rows(num_in, 0); + for (size_t i = 0; i < num_in; i++) { + if (!in[i].storage_initialized()) { + skip[i] = true; + continue; + } + auto size = in[i].aux_shape(kIdx).Size(); + num_rows[i] = size; + total_num_rows += size; + in_vals[i] = in[i].data().FlatTo2D(); + in_indices[i] = in[i].aux_data(kIdx).FlatTo1D(); + } + std::vector indices; + indices.reserve(total_num_rows); + // gather indices from all inputs + for (size_t i = 0; i < num_in; i++) { + for (size_t j = 0; j < num_rows[i]; j++) { + indices.emplace_back(in_indices[i][j]); + } + } + CHECK_EQ(indices.size(), total_num_rows); + // dedup indices + std::sort(indices.begin(), indices.end()); + indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin()); + // the one left are unique non-zero rows + size_t nnr = indices.size(); + // allocate memory for output + out->CheckAndAlloc({Shape1(nnr)}); + auto idx_data = out->aux_data(kIdx).FlatTo1D(); + auto val_data = out->data().FlatTo2D(); + + for (size_t i = 0; i < nnr; i++) { + // copy indices back + idx_data[i] = indices[i]; + bool zeros = true; + for (size_t j = 0; j < num_in; j++) { + if (skip[j]) continue; + size_t offset = offsets[j]; + if (offset < num_rows[j]) { + if (indices[i] == in_indices[j][offset]) { + if (zeros) { + Copy(val_data[i], in_vals[j][offset], nullptr); + zeros = false; + } else { + val_data[i] += in_vals[j][offset]; + } + offsets[j] += 1; + } + } + } + } + }); + }); + } + template inline static void ReduceSumCPU( const std::vector &dptr, size_t offset, index_t size) { @@ -216,8 +322,13 @@ class CommDevice : public Comm { virtual ~CommDevice() { } - void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) override { - sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + } else { + LOG(FATAL) << "storage type " << stype << " not implemented for device yet"; + } } const NDArray& Reduce(int key, const std::vector& src, diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 5f5a0cc67a64..62ec06c30fab 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -63,7 +63,7 @@ class KVStoreDist : public KVStoreLocal { const std::vector& values) override { CheckUnique(keys); for (size_t i = 0; i < keys.size(); ++i) { - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } if (get_rank() == 0) { Push_(keys, values, 0, false); diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index caa57a20d46e..5506f2c76bb3 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -43,7 +43,7 @@ class KVStoreLocal : public KVStore { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; local_[keys[i]] = values[i].Copy(pinned_ctx_); - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } } @@ -67,7 +67,11 @@ class KVStoreLocal : public KVStore { } updater_(key, merged, &local); } else { - local = merged; + if (merged.storage_type() != local.storage_type()) { + local = merged.Copy(local.ctx()); + } else { + local = merged; + } } } } diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index dd8149d4822e..a64bfcae0868 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -1,19 +1,19 @@ # pylint: skip-file import mxnet as mx import numpy as np +from mxnet.test_utils import rand_ndarray, assert_almost_equal shape = (4, 4) keys = [5, 7, 11] -def init_kv(): +def init_kv(stype='default'): """init kv """ kv = mx.kv.create() # single - kv.init(3, mx.nd.zeros(shape)) + kv.init(3, mx.sparse_nd.zeros(stype, shape)) # list - kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) + kv.init(keys, [mx.sparse_nd.zeros(stype, shape)] * len(keys)) return kv - def check_diff_to_scalar(A, x): """ assert A == x""" assert(np.sum(np.abs((A - x).asnumpy())) == 0) @@ -74,6 +74,42 @@ def test_aggregator(): for v in vv: check_diff_to_scalar(v, num_devs * 2.0) +def test_sparse_aggregator(): + """aggregate sparse ndarray on muliple devices""" + + stype = 'row_sparse' + kv = init_kv(stype) + + # devices + num_devs = 4 + devs = [mx.Context('cpu', i) for i in range(num_devs)] + + # single + vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)] + expected_sum = np.zeros(shape) + for v in vals: + expected_sum += v.asnumpy() + + kv.push(3, vals) + kv.pull(3, out = vals) + result_sum = np.zeros(shape) + for v in vals: + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) + + # list + vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys) + expected_sum = np.zeros(shape) + for v in vals[0]: + expected_sum += v.asnumpy() + + kv.push(keys, vals) + kv.pull(keys, out = vals) + for vv in vals: + result_sum = np.zeros(shape) + for v in vv: + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) def updater(key, recv, local): """use updater: +=""" @@ -121,5 +157,6 @@ def test_get_type(): test_get_type() test_single_kv_pair() test_list_kv_pair() + test_sparse_aggregator() test_aggregator() test_updater() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 470312352b0e..d2a1f7fa3a3e 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -379,12 +379,12 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N def test_module_fm(): mx.random.seed(11) rnd.seed(11) - def fm_model(k, feature_dim, storage_type='default'): - initializer = mx.initializer.Normal(sigma=0.01) - x = mx.symbol.Variable("data", storage_type=storage_type) - v = mx.symbol.Variable("v", shape=(feature_dim, k), init=initializer) + def fm_model(k, feature_dim): + norm = mx.initializer.Normal(sigma=0.01) + x = mx.symbol.Variable("data", storage_type='csr') + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse') - w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=initializer) + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm) w1 = mx.symbol.dot(x, w1_weight) v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) @@ -400,25 +400,26 @@ def fm_model(k, feature_dim, storage_type='default'): model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") return model + # model ctx = default_context() k = 5 feature_dim = 20 - model = fm_model(k, feature_dim, 'csr') + model = fm_model(k, feature_dim) + # data iter num_batches = 8 batch_size = 25 + num_samples = batch_size * num_batches import scipy.sparse as sp - scipy_data = sp.rand(num_batches * batch_size, feature_dim, - density=0.5, format='csr') - dns_label = mx.nd.ones((num_batches * batch_size,1)) - csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices, - (num_batches * batch_size, feature_dim)) - data = csr_data - - train_iter = mx.io.NDArrayIter(data=data, - label={'out_label':dns_label}, + # generate some random scipy csr data + csr_sp = sp.rand(num_samples, feature_dim, density=0.5, format='csr') + csr_nd = mx.sparse_nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, + (num_samples, feature_dim)) + label = mx.nd.ones((num_samples,1)) + # the alternative is to use LibSVMIter + train_iter = mx.io.NDArrayIter(data=csr_nd, + label={'out_label':label}, batch_size=batch_size) - # create module mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label']) # allocate memory by given the input data and lable shapes @@ -429,9 +430,7 @@ def fm_model(k, feature_dim, storage_type='default'): mod.init_optimizer(optimizer='sgd') # use accuracy as the metric metric = mx.metric.create('MSE') - # train 5 epoch, i.e. going over the data iter one pass - storage_type_dict = {'v' : 'row_sparse'} - + # train 10 epoch for epoch in range(10): train_iter.reset() metric.reset() @@ -439,11 +438,10 @@ def fm_model(k, feature_dim, storage_type='default'): mod.forward(batch, is_train=True) # compute predictions mod.update_metric(metric, batch.label) # accumulate prediction accuracy mod.backward() # compute gradients - mod.update(storage_type_dict) # update parameters + mod.update() # update parameters # print('Epoch %d, Training %s' % (epoch, metric.get())) assert(metric.get()[1] < 0.2) - if __name__ == '__main__': test_module_dtype() test_module_input_grads()