Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

RowSparse pull/push #7015

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions example/sparse/get_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# pylint: skip-file
import os, gzip
import pickle as pickle
import sys

def get_libsvm_data(data_dir, data_name, url, data_origin_name):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists(data_name)):
import urllib
zippath = os.path.join(data_dir, data_origin_name)
urllib.urlretrieve(url, zippath)
os.system("bzip2 -d %r" % data_origin_name)
os.chdir("..")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using test_utils.download

178 changes: 178 additions & 0 deletions example/sparse/linear_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import mxnet as mx
from mxnet.test_utils import *
from get_data import get_libsvm_data
import time
import argparse
import os

parser = argparse.ArgumentParser(description="Run sparse linear regression " \
"with distributed kvstore",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--profiler', type=int, default=0,
help='whether to use profiler')
parser.add_argument('--num-epoch', type=int, default=1,
help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=512,
help='number of examples per batch')
parser.add_argument('--num-batch', type=int, default=99999999,
help='number of batches per epoch')
parser.add_argument('--dummy-iter', type=int, default=0,
help='whether to use dummy iterator to exclude io cost')
parser.add_argument('--kvstore', type=str, default='dist_sync',
help='what kvstore to use [local, dist_sync, etc]')
parser.add_argument('--log-level', type=str, default='debug',
help='logging level [debug, info, error]')
parser.add_argument('--dataset', type=str, default='avazu',
help='what test dataset to use')

class DummyIter(mx.io.DataIter):
"A dummy iterator that always return the same batch, used for speed testing"
def __init__(self, real_iter):
super(DummyIter, self).__init__()
self.real_iter = real_iter
self.provide_data = real_iter.provide_data
self.provide_label = real_iter.provide_label
self.batch_size = real_iter.batch_size

for batch in real_iter:
self.the_batch = batch
break

def __iter__(self):
return self

def next(self):
return self.the_batch

# testing dataset sources
avazu = {
'data_name': 'avazu-app.t',
'data_origin_name': 'avazu-app.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
'feature_dim': 1000000,
}

kdda = {
'data_name': 'kdda.t',
'data_origin_name': 'kdda.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
'feature_dim': 20216830,
}

datasets = { 'kdda' : kdda, 'avazu' : avazu }

def regression_model(feature_dim):
initializer = mx.initializer.Normal()
x = mx.symbol.Variable("data", stype='csr')
norm_init = mx.initializer.Normal(sigma=0.01)
v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
embed = mx.symbol.dot(x, v)
y = mx.symbol.Variable("softmax_label")
model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out")
return model

if __name__ == '__main__':

# arg parser
args = parser.parse_args()
num_epoch = args.num_epoch
num_batch = args.num_batch
kvstore = args.kvstore
profiler = args.profiler > 0
batch_size = args.batch_size
dummy_iter = args.dummy_iter
dataset = args.dataset
log_level = args.log_level

# create kvstore
kv = mx.kvstore.create(kvstore)
rank = kv.rank
num_worker = kv.num_workers

# only print log for rank 0 worker
import logging
if rank != 0:
log_level = logging.ERROR
elif log_level == 'DEBUG':
log_level = logging.DEBUG
else:
log_level = logging.INFO
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=log_level, format=head)

# dataset
assert(dataset in datasets), "unknown dataset " + dataset
metadata = datasets[dataset]
feature_dim = metadata['feature_dim']
if logging:
logging.debug('preparing data ... ')
data_dir = os.path.join(os.getcwd(), 'data')
path = os.path.join(data_dir, metadata['data_name'])
if not os.path.exists(path):
get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
metadata['data_origin_name'])
assert os.path.exists(path)

# data iterator
train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,),
batch_size=batch_size, num_parts=num_worker,
part_index=rank)
if dummy_iter:
train_data = DummyIter(train_data)

# model
model = regression_model(feature_dim)

# module
mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
mod.init_params(initializer=mx.init.Uniform(scale=.1))
sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0,
learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
mod.init_optimizer(optimizer=sgd, kvstore=kv)
# use accuracy as the metric
metric = mx.metric.create('MSE')

# start profiler
if profiler:
import random
name = 'profile_output_' + str(num_worker) + '.json'
mx.profiler.profiler_set_config(mode='all', filename=name)
mx.profiler.profiler_set_state('run')

logging.debug('start training ...')
start = time.time()
data_iter = iter(train_data)
for epoch in range(num_epoch):
nbatch = 0
end_of_batch = False
data_iter.reset()
metric.reset()
next_batch = next(data_iter)
while not end_of_batch:
nbatch += 1
batch = next_batch
# TODO(haibin) remove extra copy after Jun's change
row_ids = batch.data[0].indices.copyto(mx.cpu())
# pull sparse weight
index = mod._exec_group.param_names.index('v')
kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
priority=-index, row_ids=[row_ids])
mod.forward_backward(batch)
# update parameters
mod.update()
try:
# pre fetch next batch
next_batch = next(data_iter)
if nbatch == num_batch:
raise StopIteration
except StopIteration:
end_of_batch = True
# accumulate prediction accuracy
mod.update_metric(metric, batch.label)
logging.info('epoch %d, %s' % (epoch, metric.get()))
if profiler:
mx.profiler.profiler_set_state('stop')
end = time.time()
time_cost = end - start
logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
20 changes: 20 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,26 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
const char** keys,
NDArrayHandle* vals,
int priority);

/*!
* \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
* The NDArray pulled back will be in row_sparse storage with only the specified
* row_ids present based row_ids (others rows are zeros).
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param row_ids the list of row_id NDArrays
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
const NDArrayHandle* row_ids,
int priority);

/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
Expand Down
24 changes: 24 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define MXNET_KVSTORE_H_
#include <dmlc/io.h>
#include <vector>
#include <utility>
#include <unordered_map>
#include <string>
#include <functional>
Expand Down Expand Up @@ -155,6 +156,29 @@ class KVStore {
const std::vector<NDArray*>& values,
int priority = 0) = 0;

/*!
* \brief pull a list of key-value pairs from the store.
* The NDArray pulled back will be in row_sparse storage with only the
* specified row_ids present (others rows are zeros).
* \param keys the list of keys
* \param values the list of buffers - row_id pairs
* \param priority the priority of the action.
*/
virtual void PullRowSparse(const std::vector<int>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) = 0;

/*!
* \brief pull a list of key-value pairs from the store, where each key is a string.
* The NDArray pulled back will be in row_sparse storage with only the
* specified row_ids present (others rows are zeros).
* \param keys the list of keys in string format
* \param values the list of buffers - row_id pairs
* \param priority the priority of the action.
*/
virtual void PullRowSparse(const std::vector<std::string>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
const int priority = 0) = 0;

/**
* \brief the prototype of user-defined updater
Expand Down
11 changes: 8 additions & 3 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ class NDArray {
return shape_;
}
/*!
* \return the shape of underlying chunk which stores the NDArray values.
* For default storage, it is the same as shape(). For row-sparse storage, it is the shape of
* \return the shape of underlying chunk which stores the NDArray data/value.
* It is only intended for non-default storage. For row-sparse storage, it is the shape of
* the tensor which stores the non-zero values.
*/
inline const TShape &storage_shape() const {
CHECK(ptr_ != nullptr);
CHECK_NE(storage_type(), kDefaultStorage);
return ptr_->storage_shape;
}

Expand Down Expand Up @@ -271,7 +272,11 @@ class NDArray {
if (is_none()) return false;
auto stype = storage_type();
CHECK_NE(stype, kDefaultStorage);
if (stype == kRowSparseStorage || stype == kCSRStorage) {
if (stype == kRowSparseStorage) {
CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0]);
return aux_shape(0).Size() != 0;
} else if (stype == kCSRStorage) {
CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0]);
return aux_shape(0).Size() != 0;
} else {
LOG(FATAL) << "Unknown storage type";
Expand Down
Loading