Skip to content

Commit

Permalink
Add waitall to sparse_end2end.py (apache#169)
Browse files Browse the repository at this point in the history
* Add waitall()

* Add dummy metric option

* Add header license
  • Loading branch information
reminisce authored and eric-haibin-lin committed Aug 16, 2017
1 parent 8cae272 commit bed002b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
25 changes: 24 additions & 1 deletion benchmark/python/sparse_end2end.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from mxnet.test_utils import *
import time
import argparse
Expand Down Expand Up @@ -27,6 +44,8 @@
'otherwise, use gpu(0),...,gpu(num_gpu-1)')
parser.add_argument('--output-dim', type=int, default=4,
help='number of columns of the forward output')
parser.add_argument('--dummy-metric', type=int, default=0,
help='whether to call update_metric')


def get_libsvm_data(data_dir, data_name, url, data_origin_name):
Expand Down Expand Up @@ -176,6 +195,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
# weight_array bound to executors of the contexts
weight_array = mod._exec_group.param_arrays[index]

mx.nd.waitall() # sync point for initialization
# start profiler
if profiler:
device = 'cpu'
Expand Down Expand Up @@ -215,7 +235,10 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
except StopIteration:
end_of_batch = True
# accumulate prediction accuracy
mod.update_metric(metric, batch.label)
if args.dummy_metric == 0:
mod.update_metric(metric, batch.label)
else: # call waitall to replace update_metric as sync point
mx.nd.waitall() # sync point for the current minibatch
logging.info('epoch %d, %s' % (epoch, metric.get()))
if epoch == 0:
print "num_batches = ", nbatch
Expand Down
17 changes: 17 additions & 0 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import mxnet as mx
import numpy as np
Expand Down

0 comments on commit bed002b

Please sign in to comment.