Skip to content

Commit

Permalink
Predict.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 12, 2020
1 parent 9215870 commit cad8501
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
3 changes: 1 addition & 2 deletions demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def main(client):
'nthread': 1,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')],
host_ip='127.0.0.1', port=32809)
num_boost_round=4, evals=[(dtrain, 'train')])
bst = output['booster']
history = output['history']

Expand Down
11 changes: 6 additions & 5 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ def get_worker_data_shape(self, worker):

def _get_rabit_args(worker_map, client, host_ip=None, port=None):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
assert (host_ip and port) or (host_ip is None and port is None)
msg = 'Please provide both IP and port'
assert (host_ip and port) or (host_ip is None and port is None), msg

if host_ip:
logger.info('Running tracker on: %s, %s', host_ip, str(port))
Expand All @@ -363,7 +364,7 @@ def _get_rabit_args(worker_map, client, host_ip=None, port=None):
# evaluation history is instead returned.


def train(client, params, dtrain, *args, evals=(), host_ip=None, port=None,
def train(client, params, dtrain, *args, evals=(), tracker_ip=None, tracker_port=None,
**kwargs):
'''Train XGBoost model.
Expand Down Expand Up @@ -399,7 +400,7 @@ def train(client, params, dtrain, *args, evals=(), host_ip=None, port=None,
client = _xgb_get_client(client)
workers = list(_get_client_workers(client).keys())

rabit_args = _get_rabit_args(workers, client, host_ip, port)
rabit_args = _get_rabit_args(workers, client, tracker_ip, tracker_port)

def dispatched_train(worker_addr):
'''Perform training on a single worker.'''
Expand Down Expand Up @@ -438,7 +439,7 @@ def dispatched_train(worker_addr):
return list(filter(lambda ret: ret is not None, results))[0]


def predict(client, model, data, *args):
def predict(client, model, data, *args, tracker_ip=None, tracker_port=None):
'''Run prediction with a trained booster.
.. note::
Expand Down Expand Up @@ -475,7 +476,7 @@ def predict(client, model, data, *args):
worker_map = data.worker_map
client = _xgb_get_client(client)

rabit_args = _get_rabit_args(worker_map, client)
rabit_args = _get_rabit_args(worker_map, client, tracker_ip, tracker_port)

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
Expand Down

0 comments on commit cad8501

Please sign in to comment.