[python-package] [dask] Add DaskLGBMRanker #3708

Merged · 20 commits · Jan 22, 2021
Changes from 5 commits
2 changes: 1 addition & 1 deletion .ci/test.sh
@@ -52,7 +52,7 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "lint" ]]; then
         "r-lintr>=2.0"
     pip install --user cpplint
     echo "Linting Python code"
-    pycodestyle --ignore=E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
+    pycodestyle --ignore=E402,E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
     pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|^eigen|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
     echo "Linting R code"
     Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1
82 changes: 61 additions & 21 deletions python-package/lightgbm/dask.py
@@ -16,7 +16,7 @@
 from dask.distributed import default_client, get_worker, wait
 
 from .basic import _LIB, _safe_call
-from .sklearn import LGBMClassifier, LGBMRegressor
+from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
 
 import scipy.sparse as ss
@@ -68,15 +68,22 @@ def _train_part(params, model_factory, list_of_parts, worker_addresses, return_m
     network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
     params.update(network_params)
 
+    is_ranker = model_factory.__qualname__ == 'LGBMRanker'
+
     # Concatenate many parts into one
-    parts = tuple(zip(*list_of_parts))
-    data = _concat(parts[0])
-    label = _concat(parts[1])
-    weight = _concat(parts[2]) if len(parts) == 3 else None
+    data = _concat([d['X'] for d in list_of_parts])
+    label = _concat([d['y'] for d in list_of_parts])
+    weight = _concat([d['weight'] for d in list_of_parts]) if 'weight' in list_of_parts[0] else None
 
     try:
         model = model_factory(**params)
-        model.fit(data, label, sample_weight=weight, **kwargs)
+
+        if is_ranker:
+            group = _concat([d['group'] for d in list_of_parts])
+            model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
+        else:
+            model.fit(data, y=label, sample_weight=weight, **kwargs)
+
     finally:
         _safe_call(_LIB.LGBM_NetworkFree())

@@ -91,7 +98,7 @@ def _split_to_parts(data, is_matrix):
     return parts
 
 
-def _train(client, data, label, params, model_factory, weight=None, **kwargs):
+def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
     """Inner train routine.
 
     Parameters
@@ -102,20 +109,30 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     y : dask array of shape = [n_samples]
         The target values (class labels in classification, real numbers in regression).
     params : dict
-    model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class
+    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
         Weights of training data.
+    group : array-like
+        Group/query data, used for ranking task. sum(group) = n_samples.
     """
-    # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
+    # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
     data_parts = _split_to_parts(data, is_matrix=True)
     label_parts = _split_to_parts(label, is_matrix=False)
-    if weight is None:
-        parts = list(map(delayed, zip(data_parts, label_parts)))
-    else:
-        weight_parts = _split_to_parts(weight, is_matrix=False)
-        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
+    parts = [{'X': x, 'y': y} for (x, y) in zip(data_parts, label_parts)]
+
+    # append weight, group vectors to part dicts when needed.
+    if sample_weight is not None:
+        weight_parts = _split_to_parts(sample_weight, is_matrix=False)
+        for i, d in enumerate(parts):
+            parts[i] = {**d, 'weight': weight_parts[i]}
+
+    if group is not None:
+        group_parts = _split_to_parts(group, is_matrix=False)
+        for i, d in enumerate(parts):
+            parts[i] = {**d, 'group': group_parts[i]}
 
     # Start computation in the background
+    parts = list(map(delayed, parts))
     parts = client.compute(parts)
     wait(parts)
@@ -179,7 +196,7 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
 
     Parameters
     ----------
-    model :
+    model : local lightgbm.LGBM[Classifier/Regressor/Ranker]
     data : dask array of shape = [n_samples, n_features]
         Input feature matrix.
     proba : bool
@@ -202,13 +219,13 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):

 class _LGBMModel:
 
-    def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
+    def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
         """Docstring is inherited from the LGBMModel."""
         if client is None:
             client = default_client()
 
         params = self.get_params(True)
-        model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)
+        model = _train(client, X, y, params, model_factory, sample_weight, group, **kwargs)
 
         self.set_params(**model.get_params())
         self._copy_extra_params(model, self)
@@ -233,8 +250,8 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
"""Distributed version of lightgbm.LGBMClassifier."""

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(LGBMClassifier, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
fit.__doc__ = LGBMClassifier.fit.__doc__

def predict(self, X, **kwargs):
@@ -262,7 +279,7 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):

     def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
         """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
-        return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
+        return self._fit(LGBMRegressor, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
     fit.__doc__ = LGBMRegressor.fit.__doc__
 
     def predict(self, X, **kwargs):
@@ -278,3 +295,26 @@ def to_local(self):
         model : lightgbm.LGBMRegressor
         """
         return self._to_local(LGBMRegressor)


+class DaskLGBMRanker(_LGBMModel, LGBMRanker):
+    """Docstring is inherited from the lightgbm.LGBMRanker."""
+
+    def fit(self, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
StrikerRUS (Collaborator):

@jameslamb What about init_score? Is it supported, or should we add a feature request for it?

    init_score : array-like of shape = [n_samples] or None, optional (default=None)
        Init score of training data.

jameslamb (Collaborator):

We should have a feature request. I'll write it up and add a link here.

@ffineis, could you add init_score=None here between sample_weight and group, so the order matches the sklearn interface for LGBMRanker (sample_weight=None, init_score=None, group=None)? That way, if people have existing sklearn code with positional arguments to fit(), they won't accidentally have their init_score interpreted as group.

And can you then just add a check like this?

    if init_score is not None:
        raise RuntimeError("init_score is not currently supported in lightgbm.dask")

@StrikerRUS, speaking of positional arguments, I'll open another issue where we can discuss how to handle the client argument. But let's leave that out of this PR.
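Sketched concretely, the requested signature and guard would look roughly like this (a hypothetical rendering of the suggestion above, not the committed code):

    def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
        if init_score is not None:
            raise RuntimeError("init_score is not currently supported in lightgbm.dask")
        return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)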

StrikerRUS (Collaborator):

@jameslamb Yes, sure! I agree with all of your points.

jameslamb (Collaborator):

init_score: #3807

client placement: #3808

ffineis (Contributor, PR author):

Yep, no prob

"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)
fit.__doc__ = LGBMRanker.fit.__doc__

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs)
predict.__doc__ = LGBMRanker.predict.__doc__

def to_local(self):
"""Create regular version of lightgbm.LGBMRanker from the distributed version.

Returns
-------
model : lightgbm.LGBMRanker
"""
return self._to_local(LGBMRanker)
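For context, a minimal usage sketch of the new estimator (illustrative only: the data, chunking, and parameters here are hypothetical, and a running dask.distributed cluster is assumed):

    import dask.array as da
    import numpy as np
    from distributed import Client

    import lightgbm.dask as dlgbm

    client = Client()  # assumes a dask.distributed scheduler is reachable

    # five query groups of 20 rows each; one dask chunk per group
    X = da.random.random((100, 10), chunks=(20, 10))
    y = da.random.randint(0, 4, size=100, chunks=20)  # graded relevance per row
    g = da.from_array(np.full(5, 20), chunks=1)       # run-length encoded group sizes

    ranker = dlgbm.DaskLGBMRanker(min_child_samples=1)
    ranker = ranker.fit(X, y, group=g, client=client)
    preds = ranker.predict(X).compute()  # one relevance score per row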
161 changes: 153 additions & 8 deletions tests/python_package_test/test_dask.py
@@ -1,20 +1,28 @@
 # coding: utf-8
 """Tests for lightgbm.dask module
 
 An easy way to run these tests is from the (python) docker container.
 > python -m pytest /LightGBM/tests/python_package_test/test_dask.py
 """
 import os
+import itertools
 import sys
+import time
 
 import pytest
-if not sys.platform.startswith("linux"):
-    pytest.skip("lightgbm.dask is currently supported in Linux environments", allow_module_level=True)
+if not sys.platform.startswith('linux'):
+    pytest.skip('lightgbm.dask is currently supported in Linux environments', allow_module_level=True)
 
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
+from scipy.stats import spearmanr
 import scipy.sparse
 from dask.array.utils import assert_eq
-from dask_ml.metrics import accuracy_score, r2_score
 from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
 from sklearn.datasets import make_blobs, make_regression
+from sklearn.utils import check_random_state
 
 import lightgbm
 import lightgbm.dask as dlgbm
@@ -23,7 +31,7 @@
data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]

pytestmark = [
pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface")
pytest.mark.skipif(os.getenv('TASK', '') == 'mpi', reason='Fails to run with MPI interface')
]


@@ -36,6 +44,102 @@ def listen_port():
listen_port.port = 13000


+def r2_score(dy_true, dy_pred):
+    """Helper function taken from dask_ml.metrics: computes coefficient of determination."""
+    numerator = ((dy_true - dy_pred) ** 2).sum(axis=0)
+    denominator = ((dy_true - dy_true.mean(axis=0)) ** 2).sum(axis=0)
+    return (1 - numerator / denominator).compute()
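(For reference, this helper computes the coefficient of determination, R² = 1 − Σᵢ(yᵢ − ŷᵢ)² / Σᵢ(yᵢ − ȳ)², where ȳ is the mean of the true targets; the denominator centers dy_true on its own mean.)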
jameslamb (Collaborator):
why was it necessary to replace the dask-ml one with this function?

If that's just to cut out our testing dependency on dask-ml (which would be much appreciated!), could you make that a separate PR with this change plus removing dask-ml from here:

    conda install -q -y -n $CONDA_ENV dask dask-ml distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy

I think that PR would be small and non-controversial, so we could review and merge it quickly.

ffineis (Contributor, PR author), Jan 16, 2021:

Cool, yeah, I'll follow up 3708 with a PR to remove dask-ml as a test dependency. See the reply to the comment re: accuracy_score (above).

ffineis (Contributor, PR author):

Addressed in d6e5209

jameslamb (Collaborator):

hey @ffineis is it cool if I pick up the "remove testing dependency on dask-ml" thing?

If you want to make the contribution I'll totally respect that, but if it's just extra work you'd rather not do, I'm happy to make a PR right now for it. Wanna be respectful of your time.



+def _make_ranking(n_samples=100, n_features=20, n_informative=5, gmax=1, random_gs=False, avg_gs=10, random_state=0):
+    """Generate a learning-to-rank dataset - feature vectors grouped together with
+    integer-valued graded relevance scores. Replace this with a sklearn.datasets function
+    if ranking objective becomes supported in sklearn.datasets module."""
+    rnd_generator = check_random_state(random_state)
+
+    y_vec, group_vec = np.empty((0,), dtype=int), np.empty((0,), dtype=int)
+    gid = 0
+
+    # build target, group ID vectors.
+    relvalues = range(gmax + 1)
+    while len(y_vec) < n_samples:
+        gsize = avg_gs if not random_gs else rnd_generator.poisson(avg_gs)
+        if not gsize:
+            continue
+
+        rel = rnd_generator.choice(relvalues, size=gsize, replace=True)
+        y_vec = np.append(y_vec, rel)
+        group_vec = np.append(group_vec, [gid] * gsize)
+        gid += 1
+
+    y_vec, group_vec = y_vec[0:n_samples], group_vec[0:n_samples]
+
+    # build feature data, X. Transform first few into informative features.
+    n_informative = max(min(n_features, n_informative), 0)
+    x_grid = np.linspace(0, stop=1, num=gmax + 2)
+    X = rnd_generator.uniform(size=(n_samples, n_features))
+
+    # make the first n_informative feature values bucketed according to relevance scores.
+    def bucket_fn(z):
+        return rnd_generator.uniform(x_grid[z], high=x_grid[z + 1])
+
+    for j in range(n_informative):
+        bias, coef = rnd_generator.normal(size=2)
+        X[:, j] = bias + coef * np.apply_along_axis(bucket_fn, axis=0, arr=y_vec)
+
+    return X, y_vec, group_vec


+def _create_ranking_data(n_samples=100, output='array', chunk_size=50):
+    X, y, g = _make_ranking(n_samples=n_samples, random_state=42)
+    rnd = np.random.RandomState(42)
+    w = rnd.rand(X.shape[0]) * 0.01
+    g_rle = np.array([sum([1 for _ in grp]) for _, grp in itertools.groupby(g)])
+
+    if output == 'dataframe':
+
+        # add target, weight, and group to DataFrame so that partitions abide by group boundaries.
+        X_df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
+        X = X_df.copy()
+        X_df = X_df.assign(y=y, g=g, w=w)
+
+        # set_index ensures partitions are based on group id. See https://bit.ly/3pAWyNw.
+        X_df.set_index('g', inplace=True)
+        dX = dd.from_pandas(X_df, chunksize=chunk_size)
+
+        # separate target, weight from features.
+        dy = dX['y']
+        dw = dX['w']
+        dX = dX.drop(columns=['y', 'w'])
+        dg = dX.index.to_series()
+
+        # encode group identifiers into run-length encoding, the format LightGBMRanker is expecting
+        # so that within each partition, sum(g) = n_samples.
+        dg = dg.map_partitions(lambda p: p.groupby('g', sort=False).apply(lambda z: z.shape[0]))
+
+    elif output == 'array':
+
+        # ranking arrays: one chunk per group. Each chunk must include all columns.
+        p = X.shape[1]
+        dX, dy, dw, dg = list(), list(), list(), list()
+        for g_idx, rhs in enumerate(np.cumsum(g_rle)):
+            lhs = rhs - g_rle[g_idx]
+            dX.append(da.from_array(X[lhs:rhs, :], chunks=(rhs - lhs, p)))
+            dy.append(da.from_array(y[lhs:rhs]))
+            dw.append(da.from_array(w[lhs:rhs]))
+            dg.append(da.from_array(np.array([g_rle[g_idx]])))
+
+        dX = da.concatenate(dX, axis=0)
+        dy = da.concatenate(dy, axis=0)
+        dw = da.concatenate(dw, axis=0)
+        dg = da.concatenate(dg, axis=0)
+
+    else:
+        raise ValueError('ranking data creation only supported for Dask arrays and dataframes')
+
+    return X, y, w, g_rle, dX, dy, dw, dg
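To make the run-length encoding above concrete, a small illustrative example (hypothetical values):

    import itertools
    import numpy as np

    g = np.array([0, 0, 0, 1, 1, 2])  # per-row group ids, as produced by _make_ranking
    g_rle = np.array([len(list(grp)) for _, grp in itertools.groupby(g)])
    print(g_rle)  # [3 2 1] -- the `group` format LGBMRanker expects; sums to n_samples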


 def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50):
     if objective == 'classification':
         X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
@@ -74,15 +178,14 @@ def test_classifier(output, centers, client, listen_port):
     dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
     dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
     p1 = dask_classifier.predict(dX)
-    s1 = accuracy_score(dy, p1)
+    s1 = da.average(dy == p1).compute()
     p1 = p1.compute()
 
     local_classifier = lightgbm.LGBMClassifier()
     local_classifier.fit(X, y, sample_weight=w)
     p2 = local_classifier.predict(X)
     s2 = local_classifier.score(X, y)
 
-    assert_eq(s1, s2)
+    assert np.isclose(s1, s2)
 
     assert_eq(p1, p2)
     assert_eq(y, p1)
@@ -180,7 +283,49 @@ def test_regressor_local_predict(client, listen_port):

     # Predictions and scores should be the same
     assert_eq(p1, p2)
-    assert_eq(s1, s2)
+    assert np.isclose(s1, s2, rtol=1e-4)


+@pytest.mark.parametrize('output', ['array', 'dataframe'])
+def test_ranker(output, client, listen_port):
+    X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output)
+
+    # Avoid lightgbm.basic.LightGBMError: Binding port 13xxx failed exceptions.
+    time.sleep(10)
+
+    dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, seed=42, min_child_samples=1)
jameslamb (Collaborator):

I just made some changes in #3786 to speed up the tests in this module. If it'll work for the tests you're adding in this PR, can you please make every DaskLGBMRanker and LGBMRanker use n_iterations=10, num_leaves=10? I found that training smaller models was able to make the tests faster without sacrificing any coverage.

+    dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
+    rnkvec_dask = dask_ranker.predict(dX)
+    rnkvec_dask = rnkvec_dask.compute()
+
+    local_ranker = lightgbm.LGBMRanker(seed=42, min_child_samples=1)
+    local_ranker.fit(X, y, sample_weight=w, group=g)
+    rnkvec_local = local_ranker.predict(X)
+
+    # distributed ranker should be able to rank decently well.
+    dcor = spearmanr(rnkvec_dask, y).correlation
+    assert dcor > 0.6
+
+    # relative difference between distributed ranker and local ranker spearman corr should be small.
+    lcor = spearmanr(rnkvec_local, y).correlation
+    assert np.abs(dcor - lcor) / lcor < 0.03

+@pytest.mark.parametrize('output', ['array', 'dataframe'])
+def test_ranker_local_predict(output, client, listen_port):
+    X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output)
+
+    time.sleep(10)
+
+    dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, seed=42, min_child_samples=1)
+    dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
+    rnkvec_dask = dask_ranker.predict(dX)
+    rnkvec_dask = rnkvec_dask.compute()
+
+    rnkvec_local = dask_ranker.to_local().predict(X)
+
+    # distributed and to-local scores should be the same.
+    assert_eq(rnkvec_dask, rnkvec_local)
Review thread on lines +416 to +419. StrikerRUS (Collaborator):

@jameslamb I like this test very much! Can we add the same for the other estimators? It seems that the Classifier tests do not assert that the real Dask predict() equals .to_local().predict().

jameslamb (Collaborator):
the classifier tests do check that

I'd be open to a PR that makes it more explicit and cuts down the total number of tests by combining test_classifier_local_predict (def test_classifier_local_predict(client, listen_port):) into test_classifier (def test_classifier(output, centers, client, listen_port):), though!

Not an action we need on this PR, I'll write it up as a "good first issue".

StrikerRUS (Collaborator):
@jameslamb

> comparison between Dask predict() and to_local().predict():

Actually, this checks to_local().predict(dX) against the local predict(). It is not the same as comparing to_local().predict(dX) with the Dask predict(), given that the asserts you are pointing to are placed in two different (in theory unrelated) tests, I think.

jameslamb (Collaborator):
it's confusing because it's two tests, so I think we should merge the tests in the future like I said above (not anything you need to do here, @ffineis).

I'm saying that taken together, they show transitively that dask_model.predict() == dask_model.to_local().predict(), even though neither one does that direct comparison:

  • (line 92) dask_model.predict() == sklearn_model.predict()
  • (line 138) dask_model.to_local().predict() == sklearn_model.predict()

But that's only partially true, since test_local_predict only does binary classification, and only on array data. So combining them would give more coverage and would show whether that result holds up for multi-class classification or for other input data types.

StrikerRUS (Collaborator):
> it's confusing because it's two tests, so I think we should merge the tests in the future like I said above

Totally agree!

> I'm saying that taken together, they show transitively that dask_model.predict() == dask_model.to_local().predict()

Also agree. But two different tests can take different inputs (params, data), can be modified separately in the future, etc., so it may not be a fully fair comparison. Merging them into one test will solve all of these issues. 🙂
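A direct version of the comparison being discussed could look like this (a hypothetical sketch in the style of this module's tests, not code from the PR):

    def test_classifier_predict_matches_to_local_predict(client, listen_port):
        # fit one distributed model, then compare Dask predict() with to_local().predict()
        X, y, w, dX, dy, dw = _create_data('classification', output='array')
        dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
        dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
        p_dask = dask_classifier.predict(dX).compute()
        p_local = dask_classifier.to_local().predict(X)
        assert_eq(p_dask, p_local)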



def test_build_network_params():