update module API for other submodules

update stypes in kvstore after refactoring

change type of size from size_t to int64_t

add sparse linear regression example

remove sparse_pull_dict from module

fix init_optim for seq_module. update sparse example
eric-haibin-lin committed Jul 26, 2017
1 parent bdd7de7 commit 60cac0b
Showing 10 changed files with 250 additions and 57 deletions.
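In short, this change removes `sparse_pull_dict` from the module API: callers now pull row_sparse weights from the kvstore explicitly, with the row ids needed for the batch, before each forward/backward. A minimal sketch of the new pattern, adapted from the example/sparse/linear_regression.py file added below (it assumes `import mxnet as mx` and that `mod`, `kv`, `sgd`, and `train_data` are set up as in that example, with a row_sparse weight named 'v'):

    # sketch of the post-change workflow; see example/sparse/linear_regression.py below
    mod.init_optimizer(optimizer=sgd, kvstore=kv)   # no sparse_pull_dict argument anymore
    for batch in train_data:
        # pull only the rows of the row_sparse weight 'v' that this batch touches
        row_ids = batch.data[0].indices.copyto(mx.cpu())
        index = mod._exec_group.param_names.index('v')
        kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
                           priority=-index, row_ids=[row_ids])
        mod.forward_backward(batch)
        mod.update()   # update() no longer takes sparse_pull_dict either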
15 changes: 15 additions & 0 deletions example/sparse/get_data.py
@@ -0,0 +1,15 @@
# pylint: skip-file
import os, gzip
import pickle as pickle
import sys

def get_libsvm_data(data_dir, data_name, url, data_origin_name):
    if not os.path.isdir(data_dir):
        os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists(data_name)):
        import urllib
        zippath = os.path.join(data_dir, data_origin_name)
        urllib.urlretrieve(url, zippath)
        os.system("bzip2 -d %r" % data_origin_name)
    os.chdir("..")
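A hedged usage sketch of this helper (the URL and file names are the avazu entries from the example below; an absolute data_dir is safest because the helper chdirs into data_dir before joining the download path onto it):

    import os
    url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2"
    data_dir = os.path.join(os.getcwd(), 'data')   # absolute, as linear_regression.py passes it
    get_libsvm_data(data_dir, 'avazu-app.t', url, 'avazu-app.t.bz2')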
178 changes: 178 additions & 0 deletions example/sparse/linear_regression.py
@@ -0,0 +1,178 @@
import mxnet as mx
from mxnet.test_utils import *
from get_data import get_libsvm_data
import time
import argparse
import os

parser = argparse.ArgumentParser(description="Run sparse linear regression " \
                                 "with distributed kvstore",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--profiler', type=int, default=0,
                    help='whether to use profiler')
parser.add_argument('--num-epoch', type=int, default=1,
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=512,
                    help='number of examples per batch')
parser.add_argument('--num-batch', type=int, default=99999999,
                    help='number of batches per epoch')
parser.add_argument('--dummy-iter', type=int, default=0,
                    help='whether to use dummy iterator to exclude io cost')
parser.add_argument('--kvstore', type=str, default='dist_sync',
                    help='what kvstore to use [local, dist_sync, etc]')
parser.add_argument('--log-level', type=str, default='debug',
                    help='logging level [debug, info, error]')
parser.add_argument('--dataset', type=str, default='avazu',
                    help='what test dataset to use')

class DummyIter(mx.io.DataIter):
    "A dummy iterator that always returns the same batch, used for speed testing"
    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size

        for batch in real_iter:
            self.the_batch = batch
            break

    def __iter__(self):
        return self

    def next(self):
        return self.the_batch

# testing dataset sources
avazu = {
    'data_name': 'avazu-app.t',
    'data_origin_name': 'avazu-app.t.bz2',
    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
    'feature_dim': 1000000,
}

kdda = {
    'data_name': 'kdda.t',
    'data_origin_name': 'kdda.t.bz2',
    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
    'feature_dim': 20216830,
}

datasets = { 'kdda' : kdda, 'avazu' : avazu }

def regression_model(feature_dim):
    initializer = mx.initializer.Normal()
    x = mx.symbol.Variable("data", stype='csr')
    norm_init = mx.initializer.Normal(sigma=0.01)
    v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
    embed = mx.symbol.dot(x, v)
    y = mx.symbol.Variable("softmax_label")
    model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out")
    return model

if __name__ == '__main__':

    # arg parser
    args = parser.parse_args()
    num_epoch = args.num_epoch
    num_batch = args.num_batch
    kvstore = args.kvstore
    profiler = args.profiler > 0
    batch_size = args.batch_size
    dummy_iter = args.dummy_iter
    dataset = args.dataset
    log_level = args.log_level

    # create kvstore
    kv = mx.kvstore.create(kvstore)
    rank = kv.rank
    num_worker = kv.num_workers

    # only print log for rank 0 worker
    import logging
    if rank != 0:
        log_level = logging.ERROR
    elif log_level == 'debug':
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=log_level, format=head)

    # dataset
    assert(dataset in datasets), "unknown dataset " + dataset
    metadata = datasets[dataset]
    feature_dim = metadata['feature_dim']
    if logging:
        logging.debug('preparing data ... ')
    data_dir = os.path.join(os.getcwd(), 'data')
    path = os.path.join(data_dir, metadata['data_name'])
    if not os.path.exists(path):
        get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
                        metadata['data_origin_name'])
    assert os.path.exists(path)

    # data iterator
    train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,),
                                  batch_size=batch_size, num_parts=num_worker,
                                  part_index=rank)
    if dummy_iter:
        train_data = DummyIter(train_data)

    # model
    model = regression_model(feature_dim)

    # module
    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0,
                           learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
    mod.init_optimizer(optimizer=sgd, kvstore=kv)
    # use MSE as the metric
    metric = mx.metric.create('MSE')

    # start profiler
    if profiler:
        import random
        name = 'profile_output_' + str(num_worker) + '.json'
        mx.profiler.profiler_set_config(mode='all', filename=name)
        mx.profiler.profiler_set_state('run')

    logging.debug('start training ...')
    start = time.time()
    data_iter = iter(train_data)
    for epoch in range(num_epoch):
        nbatch = 0
        end_of_batch = False
        data_iter.reset()
        metric.reset()
        next_batch = next(data_iter)
        while not end_of_batch:
            nbatch += 1
            batch = next_batch
            # TODO(haibin) remove extra copy after Jun's change
            row_ids = batch.data[0].indices.copyto(mx.cpu())
            # pull sparse weight
            index = mod._exec_group.param_names.index('v')
            kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
                               priority=-index, row_ids=[row_ids])
            mod.forward_backward(batch)
            # update parameters
            mod.update()
            try:
                # pre fetch next batch
                next_batch = next(data_iter)
                if nbatch == num_batch:
                    raise StopIteration
            except StopIteration:
                end_of_batch = True
            # accumulate the evaluation metric
            mod.update_metric(metric, batch.label)
        logging.info('epoch %d, %s' % (epoch, metric.get()))
    if profiler:
        mx.profiler.profiler_set_state('stop')
    end = time.time()
    time_cost = end - start
    logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
15 changes: 7 additions & 8 deletions python/mxnet/kvstore.py
@@ -5,7 +5,7 @@
 import ctypes
 import pickle
 from .ndarray import NDArray
-from .sparse_ndarray import _ndarray_cls
+from .ndarray import _ndarray_cls
 from .base import _LIB
 from .base import check_call, c_array, c_str, string_types, mx_uint, py_str
 from .base import NDArrayHandle, KVStoreHandle
@@ -221,10 +221,10 @@ def pull(self, key, out=None, priority=0):
             out = [out]
         for val in out:
             if not isinstance(val, (list, tuple)):
-                assert(val.storage_type == 'default')
+                assert(val.stype == 'default')
             else:
                 for v in val:
-                    assert(v.storage_type == 'default')
+                    assert(v.stype == 'default')
         ckeys, cvals = _ctype_key_value(key, out)
         check_call(_LIB.MXKVStorePullEx(
             self.handle, mx_uint(len(ckeys)), ckeys, cvals,
@@ -245,7 +245,7 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
             Keys.
         out: NDArray or list of NDArray or list of list of NDArray
-            Values corresponding to the keys. The storage_type is expected to be row_sparse
+            Values corresponding to the keys. The stype is expected to be row_sparse
         priority : int, optional
             The priority of the pull operation.
@@ -287,14 +287,13 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
             out = [out]
         for val in out:
             if not isinstance(val, (list, tuple)):
-                assert(val.storage_type == 'row_sparse')
+                assert(val.stype == 'row_sparse')
             else:
                 for v in val:
-                    assert(v.storage_type == 'row_sparse')
+                    assert(v.stype == 'row_sparse')
         ckeys, cvals = _ctype_key_value(key, out)
         _, crow_ids = _ctype_key_value(key, row_ids)
-        assert(len(crow_ids) == len(cvals)), (len(crow_ids), len(cvals))
-        #TODO(haibin) pickup upstream changes which removed `_cast_to_str_keys`
+        assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
 
         check_call(_LIB.MXKVStorePullRowSparse(
             self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
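As a sketch of the call this reworked assert guards (a hypothetical single-key case; the `stype` keyword and int64 row ids are assumptions about the post-refactor NDArray API, not part of this diff): `out` and `row_ids` are matched element-wise, so their lengths must agree.

    import mxnet as mx
    kv = mx.kvstore.create('local')
    kv.init('w', mx.nd.zeros((8, 2), stype='row_sparse'))   # assumed stype keyword
    out = mx.nd.zeros((8, 2), stype='row_sparse')
    row_ids = mx.nd.array([1, 5], dtype='int64')            # rows to fetch
    kv.row_sparse_pull('w', out=out, row_ids=[row_ids])     # one row_ids entry per output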
42 changes: 26 additions & 16 deletions python/mxnet/model.py
@@ -76,23 +76,31 @@ def _create_kvstore(kvstore, num_device, arg_params):
 
     return (kv, update_on_kvstore)
 
-def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
-                        update_on_kvstore, sparse_pull_dict=None):
+def _contains_non_default_storage(params):
+    if isinstance(params, (list, tuple)):
+        for param in params:
+            if param.stype != 'default':
+                return True
+    elif isinstance(params, NDArray):
+        return params.stype != 'default'
+    else:
+        return False
+
+def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
     """Initialize kvstore"""
     for idx, param_on_devs in enumerate(param_arrays):
         name = param_names[idx]
         kvstore.init(name, arg_params[name])
 
         if update_on_kvstore:
-            if sparse_pull_dict is not None and name in sparse_pull_dict:
-                kvstore.row_sparse_pull(name, param_on_devs, priority=-idx,
-                                        row_ids=sparse_pull_dict[name])
+            if _contains_non_default_storage(param_on_devs):
+                # skip pulling row_sparse weights
+                warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
+                              'sure to pull it with row_ids explicitly', RuntimeWarning)
             else:
                 kvstore.pull(name, param_on_devs, priority=-idx)
 
-def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names,
-                              sparse_pull_dict=None):
+def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
     """Perform update of param_arrays from grad_arrays on kvstore."""
     for index, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
@@ -102,14 +110,15 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
         # push gradient, priority is negative index
         kvstore.push(name, grad_list, priority=-index)
         # pull back the weights
-        if sparse_pull_dict is not None and name in sparse_pull_dict:
-            kvstore.row_sparse_pull(name, arg_list, priority=-index,
-                                    row_ids=sparse_pull_dict[name])
+        if _contains_non_default_storage(arg_list):
+            # skip pulling row_sparse weights
+            warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
+                          'sure to pull it with row_ids', RuntimeWarning)
         else:
             kvstore.pull(name, arg_list, priority=-index)
 
 def _update_params(param_arrays, grad_arrays, updater, num_device,
-                   kvstore=None, param_names=None, sparse_pull_dict=None):
+                   kvstore=None, param_names=None):
     """Perform update of param_arrays from grad_arrays not on kvstore."""
     for i, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
@@ -120,11 +129,12 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
             name = param_names[index]
             # push gradient, priority is negative index
             kvstore.push(name, grad_list, priority=-index)
-            if sparse_pull_dict is not None and name in sparse_pull_dict:
-                kvstore.row_sparse_pull(name, grad_list, priority=-index,
-                                        row_ids=sparse_pull_dict[name])
-            # pull back the sum gradients, to the same locations.
-            kvstore.pull(name, grad_list, priority=-index)
+            if _contains_non_default_storage(grad_list):
+                # skip pulling row_sparse weights
+                warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
+                              'sure to pull it with row_ids', RuntimeWarning)
+            else:
+                # pull back the sum gradients, to the same locations.
+                kvstore.pull(name, grad_list, priority=-index)
         for k, p in enumerate(zip(arg_list, grad_list)):
             # faked an index here, to make optimizer create diff
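A hedged illustration of the new _contains_non_default_storage helper above (assuming the post-refactor `stype` attribute and `stype` creation keyword on NDArray; note the list branch falls through to an implicit None, which is falsy, when every entry is dense):

    import mxnet as mx
    from mxnet.model import _contains_non_default_storage   # private helper shown above
    dense = mx.nd.zeros((4, 2))                       # stype == 'default'
    sparse = mx.nd.zeros((4, 2), stype='row_sparse')  # assumed stype keyword
    _contains_non_default_storage([dense, sparse])    # True -> pull skipped with a RuntimeWarning
    _contains_non_default_storage([dense])            # None (falsy) -> regular kvstore.pull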
3 changes: 1 addition & 2 deletions python/mxnet/module/base_module.py
@@ -932,8 +932,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         raise NotImplementedError()
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
-                       optimizer_params=(('learning_rate', 0.01),), force_init=False,
-                       sparse_pull_dict=None):
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
         """Installs and initializes optimizers, as well as initialize kvstore for
         distributed training
18 changes: 5 additions & 13 deletions python/mxnet/module/module.py
@@ -429,8 +429,7 @@ def reshape(self, data_shapes, label_shapes=None):
         self._exec_group.reshape(self._data_shapes, self._label_shapes)
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
-                       optimizer_params=(('learning_rate', 0.01),), force_init=False,
-                       sparse_pull_dict=None):
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
         """Installs and initializes optimizers.
 
         Parameters
@@ -445,10 +444,6 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
         force_init : bool
             Default ``False``, indicating whether we should force re-initializing the
             optimizer in the case an optimizer is already installed.
-        sparse_pull_dict : dict of str -> list of NDArray
-            Default to `None`, used for distributed training with sparse parameters.
-            When the name of a row_sparse parameter is in the dict, the initial value pulled
-            to devices will only contain the rows specified by the list of row_id NDArrays.
         """
         assert self.binded and self.params_initialized
@@ -502,8 +497,7 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                 param_arrays=self._exec_group.param_arrays,
                 arg_params=self._arg_params,
                 param_names=self._param_names,
-                update_on_kvstore=update_on_kvstore,
-                sparse_pull_dict=sparse_pull_dict)
+                update_on_kvstore=update_on_kvstore)
             if update_on_kvstore:
                 kvstore.set_optimizer(self._optimizer)
             else:
@@ -564,7 +558,7 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._exec_group.backward(out_grads=out_grads)
 
-    def update(self, sparse_pull_dict=None):
+    def update(self):
         """Updates parameters according to the installed optimizer and the gradients computed
         in the previous forward-backward batch.
@@ -578,16 +572,14 @@ def update(self):
         if self._update_on_kvstore:
             _update_params_on_kvstore(self._exec_group.param_arrays,
                                       self._exec_group.grad_arrays,
-                                      self._kvstore, self._exec_group.param_names,
-                                      sparse_pull_dict=sparse_pull_dict)
+                                      self._kvstore, self._exec_group.param_names)
         else:
             _update_params(self._exec_group.param_arrays,
                            self._exec_group.grad_arrays,
                            updater=self._updater,
                            num_device=len(self._context),
                            kvstore=self._kvstore,
-                           param_names=self._exec_group.param_names,
-                           sparse_pull_dict=sparse_pull_dict)
+                           param_names=self._exec_group.param_names)
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs of the previous forward computation.
(diffs for the remaining 4 changed files not loaded)
