Skip to content

Commit

Permalink
faster tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jan 19, 2021
1 parent 4397022 commit 6428589
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,27 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
return results[0]


def _predict_part(part, model, pred_proba, pred_leaf, pred_contrib, **kwargs):
def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
data = part.values if isinstance(part, pd.DataFrame) else part

if data.shape[0] == 0:
result = np.array([])
elif pred_proba:
result = model.predict_proba(data, pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
result = model.predict_proba(
data,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
**kwargs
)
else:
result = model.predict(data, pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
result = model.predict(
data,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
**kwargs
)

if isinstance(part, pd.DataFrame):
if pred_proba or pred_contrib:
Expand All @@ -247,7 +259,7 @@ def _predict_part(part, model, pred_proba, pred_leaf, pred_contrib, **kwargs):
return result


def _predict(model, data, pred_proba=False, pred_leaf=False, pred_contrib=False,
def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
dtype=np.float32, **kwargs):
"""Inner predict routine.
Expand All @@ -270,6 +282,7 @@ def _predict(model, data, pred_proba=False, pred_leaf=False, pred_contrib=False,
return data.map_partitions(
_predict_part,
model=model,
raw_score=raw_score,
pred_proba=pred_proba,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
Expand All @@ -283,6 +296,7 @@ def _predict(model, data, pred_proba=False, pred_leaf=False, pred_contrib=False,
return data.map_blocks(
_predict_part,
model=model,
raw_score=raw_score,
pred_proba=pred_proba,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
else:
expected_num_cols = (dX.shape[1] + 1) * num_classes

if isinstance(dX, dask.dataframe.core.DataFrame):
if isinstance(dX, dask.dataframe.DataFrame):
assert preds_with_contrib.shape == (dX.shape[0].compute(), expected_num_cols)
else:
assert preds_with_contrib.shape == (dX.shape[0], expected_num_cols)
Expand Down

0 comments on commit 6428589

Please sign in to comment.