Skip to content

Commit

Permalink
serial elemwise sum impl (apache#80)
Browse files Browse the repository at this point in the history
update module kvstore interface

add other missing params and functions

revert some interface changes

revert some more changes

remove explicit casting for gradients on kvstore

update Comm interface

update fm example

Conflicts:
	python/mxnet/model.py
	python/mxnet/ndarray.py
  • Loading branch information
eric-haibin-lin authored Jun 10, 2017
1 parent c8d3742 commit 16a6d7f
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 91 deletions.
24 changes: 6 additions & 18 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
stype_dict=None, param_names=None):
def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names=None):
"""Perform update of param_arrays from grad_arrays on kvstore.
If `param_names` is None or kvstore doesn't have a `name2idx` dictionary,
the index of a param is determined by the order it appears in `param_arrays`. """
stype_dict = {} if stype_dict is None else stype_dict
for i, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
Expand All @@ -99,31 +97,18 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
if param_names is not None:
name = param_names[i]
index = index if name not in kvstore.name2idx else kvstore.name2idx[name]
# cast storage type if stype doesn't match
if name in stype_dict:
for j, grad in enumerate(grad_list):
stype = stype_dict[name]
if grad_list[j].storage_type != stype:
grad_list[j] = nd.cast_storage(grad, stype)
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
# pull back the weights
kvstore.pull(index, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, stype_dict=None, param_names=None):
kvstore=None, param_names=None):
"""Perform update of param_arrays from grad_arrays not on kvstore."""
stype_dict = {} if stype_dict is None else stype_dict
for i, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
# cast storage type if stype doesn't match
if param_names is not None and param_names[i] in stype_dict:
for j, grad in enumerate(grad_list):
stype = stype_dict[param_names[i]]
if grad_list[j].storage_type != stype:
grad_list[j] = nd.cast_storage(grad, stype)
index = i
if kvstore:
if param_names is not None:
Expand All @@ -136,8 +121,11 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
# use a better solution latter
# use a better solution later
w, g = p
# cast storage type if stype doesn't match
if g.storage_type != w.storage_type:
g = nd.cast_storage(g, w.storage_type)
updater(index*num_device+k, g, w)


Expand Down
12 changes: 2 additions & 10 deletions python/mxnet/module/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,17 +849,9 @@ def get_input_grads(self, merge_multi_context=True):
"""
raise NotImplementedError()

def update(self, storage_type_dict=None):
def update(self):
"""Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch. The storage type of parameters is casted according
to `storage_type_dict`, if provided.
Parameters
----------
storage_type_dict: dict of str to str
Defaults to ``None``. Desired storage types of parameters for parameter update. If the
parameter gradient is not of desired storage type, its storage type will be casted
before the update.
in the previous forward-backward batch.
Examples
--------
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/module/bucketing_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ def backward(self, out_grads=None):
assert self.binded and self.params_initialized
self._curr_module.backward(out_grads=out_grads)

def update(self, storage_type_dict=None):
def update(self):
"""Updates parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
self._curr_module.update(storage_type_dict=storage_type_dict)
self._curr_module.update()

def get_outputs(self, merge_multi_context=True):
"""Gets outputs from a previous forward computation.
Expand Down
3 changes: 1 addition & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def backward(self, out_grads=None):
assert self.binded and self.params_initialized
self._exec_group.backward(out_grads=out_grads)

def update(self, storage_type_dict=None):
def update(self):
"""Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
Expand All @@ -577,7 +577,6 @@ def update(self, storage_type_dict=None):
_update_params_on_kvstore(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
self._kvstore,
stype_dict=storage_type_dict,
param_names=self._param_names)
else:
_update_params(self._exec_group.param_arrays,
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/module/python_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
"""
pass

def update(self, storage_type_dict=None):
def update(self):
"""Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch. Currently we do nothing here. Subclass should
override this method if contains parameters.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/module/sequential_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ def backward(self, out_grads=None):

out_grads = module.get_input_grads()

def update(self, storage_type_dict=None):
def update(self):
"""Updates parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized

for module in self._modules:
module.update(storage_type_dict=storage_type_dict)
module.update()

def get_outputs(self, merge_multi_context=True):
"""Gets outputs from a previous forward computation.
Expand Down
6 changes: 4 additions & 2 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def empty(shape, ctx=None, dtype=mx_real_t):
ctx = Context.default_ctx
return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))

def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs):
def zeros(shape, ctx=None, dtype=None, **kwargs):
"""Returns a new array filled with all zeros, with the given shape and type.
Parameters
Expand Down Expand Up @@ -1053,11 +1053,12 @@ def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs):
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
# pylint: enable= no-member, protected-access

def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
def ones(shape, ctx=None, dtype=None, **kwargs):
"""Returns a new array filled with all ones, with the given shape and type.
Parameters
Expand Down Expand Up @@ -1089,6 +1090,7 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
# pylint: enable= no-member, protected-access
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
array([[ 0., 0.]], dtype=float16)
"""
if storage_type == 'default':
return ndarray.zeros(shape, ctx, dtype, **kwargs)
return ndarray.zeros(shape, ctx=ctx, dtype=dtype, **kwargs)
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
Expand Down
161 changes: 136 additions & 25 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ class Comm {
}
virtual ~Comm() { }
/**
* \brief init key with the data shape
* \brief init key with the data shape and storage shape
*/
virtual void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) = 0;
virtual void Init(int key, const NDArrayStorageType stype,
const TShape& shape, int dtype = mshadow::kFloat32) = 0;
/**
* \brief returns src[0] + .. + src[src.size()-1]
*/
Expand Down Expand Up @@ -67,8 +68,13 @@ class CommCPU : public Comm {
}
virtual ~CommCPU() { }

void Init(int key, const TShape& shape, int type = mshadow::kFloat32) override {
merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
// Pre-allocates the per-key merge buffer used by Reduce().
// Dense (kDefaultStorage) keys get a pinned-context NDArray allocated up
// front; any other storage type constructs the NDArray via the stype
// overload with the last flag set to true — presumably delayed allocation,
// since the number of non-zero rows is unknown until reduce time
// (TODO confirm against the NDArray constructor).
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int type = mshadow::kFloat32) override {
if (stype == kDefaultStorage) {
merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
} else {
merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type);
}
}

const NDArray& Reduce(int key, const std::vector<NDArray>& src,
Expand All @@ -78,29 +84,56 @@ class CommCPU : public Comm {
if (src.size() == 1) {
return src[0];
}
std::vector<Engine::VarHandle> const_vars(src.size() - 1);
std::vector<NDArray> reduce(src.size());
auto& buf = merge_buf_[key];
CopyFromTo(src[0], &buf.merged, priority);
reduce[0] = buf.merged;
if (buf.merged.storage_type() == kDefaultStorage) {
std::vector<Engine::VarHandle> const_vars(src.size() - 1);
std::vector<NDArray> reduce(src.size());
CopyFromTo(src[0], &buf.merged, priority);
reduce[0] = buf.merged;

if (buf.copy_buf.empty()) {
buf.copy_buf.resize(src.size()-1);
for (size_t j = 0; j < src.size() - 1; ++j) {
buf.copy_buf[j] = NDArray(
src[0].shape(), pinned_ctx_, false, src[0].dtype());
if (buf.copy_buf.empty()) {
buf.copy_buf.resize(src.size()-1);
for (size_t j = 0; j < src.size() - 1; ++j) {
// allocate NDArray basd on storage type
buf.copy_buf[j] = NDArray(
src[0].shape(), pinned_ctx_, false, src[0].dtype());
}
}
for (size_t i = 1; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
reduce[i] = buf.copy_buf[i-1];
const_vars[i-1] = reduce[i].var();
}
}
for (size_t i = 1; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
reduce[i] = buf.copy_buf[i-1];
const_vars[i-1] = reduce[i].var();
}

Engine::Get()->PushSync([reduce, this](RunContext rctx) {
ReduceSumCPU(reduce);
}, Context::CPU(), const_vars, {reduce[0].var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
Engine::Get()->PushSync([reduce, this](RunContext rctx) {
ReduceSumCPU(reduce);
}, Context::CPU(), const_vars, {reduce[0].var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));

} else {
// buf.merged is a sparse ndarray.
std::vector<Engine::VarHandle> const_vars(src.size());
std::vector<NDArray> reduce(src.size());

if (buf.copy_buf.empty()) {
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(
src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype());
}
}
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
const_vars[i] = reduce[i].var();
}
auto result = buf.merged;
Engine::Get()->PushSync([reduce, result, this](RunContext rctx) {
NDArray out = result;
ReduceSumCPUEx(reduce, &out);
}, Context::CPU(), const_vars, {result.var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
}

return buf.merged;
}
Expand Down Expand Up @@ -133,6 +166,79 @@ class CommCPU : public Comm {
});
}

// Serial implementation of reduce sum for row sparse NDArray.
// Merges the row_sparse inputs `in` into `out` in three passes:
//   1) gather every input's row indices, 2) sort + dedup them to get the
//   union of non-zero rows, 3) for each unique row, sum the matching value
//   rows from each input using one monotonically advancing cursor
//   (offsets[j]) per input.
// NOTE(review): the cursor merge assumes each input's kIdx indices are
// stored in ascending order (row_sparse invariant) — confirm upstream.
// Inputs with uninitialized storage (no indices/values) are skipped.
// TODO(haibin) use openmp kernel to parallelize the summation
inline void ReduceSumCPUEx(const std::vector<NDArray> &in, NDArray *out) {
using namespace rowsparse;
using namespace mshadow;
auto stype = out->storage_type();
CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype;
// total number of (possibly duplicated) non-zero rows across all inputs
size_t total_num_rows = 0;
size_t num_in = in.size();
// skip the ones with empty indices and values
std::vector<bool> skip(num_in, false);
// the values tensor of the inputs
MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
std::vector<Tensor<cpu, 2, DType>> in_vals(num_in);
std::vector<Tensor<cpu, 1, IType>> in_indices(num_in);
// per-input cursor into its index/value rows during the merge pass
std::vector<size_t> offsets(num_in, 0);
std::vector<size_t> num_rows(num_in, 0);
// pass 1: collect views and per-input row counts
for (size_t i = 0; i < num_in; i++) {
if (!in[i].storage_initialized()) {
skip[i] = true;
continue;
}
auto size = in[i].aux_shape(kIdx).Size();
num_rows[i] = size;
total_num_rows += size;
in_vals[i] = in[i].data().FlatTo2D<cpu, DType>();
in_indices[i] = in[i].aux_data(kIdx).FlatTo1D<cpu, IType>();
}
std::vector<IType> indices;
indices.reserve(total_num_rows);
// gather indices from all inputs
for (size_t i = 0; i < num_in; i++) {
for (size_t j = 0; j < num_rows[i]; j++) {
indices.emplace_back(in_indices[i][j]);
}
}
CHECK_EQ(indices.size(), total_num_rows);
// pass 2: dedup indices
std::sort(indices.begin(), indices.end());
indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin());
// the ones left are the unique non-zero rows of the result
size_t nnr = indices.size();
// allocate memory for output (nnr index entries + nnr value rows)
out->CheckAndAlloc({Shape1(nnr)});
auto idx_data = out->aux_data(kIdx).FlatTo1D<cpu, IType>();
auto val_data = out->data().FlatTo2D<cpu, DType>();

// pass 3: for each unique row, accumulate matching rows from every input
for (size_t i = 0; i < nnr; i++) {
// copy indices back
idx_data[i] = indices[i];
// true until the first contribution: copy then, add afterwards,
// so the (uninitialized) output row is never read before written
bool zeros = true;
for (size_t j = 0; j < num_in; j++) {
if (skip[j]) continue;
size_t offset = offsets[j];
if (offset < num_rows[j]) {
if (indices[i] == in_indices[j][offset]) {
if (zeros) {
Copy(val_data[i], in_vals[j][offset], nullptr);
zeros = false;
} else {
val_data[i] += in_vals[j][offset];
}
offsets[j] += 1;
}
}
}
}
});
});
}

template<typename DType>
inline static void ReduceSumCPU(
const std::vector<DType*> &dptr, size_t offset, index_t size) {
Expand Down Expand Up @@ -216,8 +322,13 @@ class CommDevice : public Comm {

virtual ~CommDevice() { }

void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) override {
sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
// Records the (key, shape, dtype) attributes for a dense key; judging by
// the name sorted_key_attrs_, actual buffer setup happens lazily elsewhere
// — TODO confirm where it is consumed. Sparse storage types are not
// supported by the device comm yet, so anything non-default is fatal.
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
if (stype == kDefaultStorage) {
sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
} else {
LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
}
}

const NDArray& Reduce(int key, const std::vector<NDArray>& src,
Expand Down
2 changes: 1 addition & 1 deletion src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class KVStoreDist : public KVStoreLocal {
const std::vector<NDArray>& values) override {
CheckUnique(keys);
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].shape(), values[i].dtype());
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
}
if (get_rank() == 0) {
Push_(keys, values, 0, false);
Expand Down
Loading

0 comments on commit 16a6d7f

Please sign in to comment.