[dask] lightgbm + dask generates crazy predictions #4695
Comments
Could you upgrade |
Thanks very much for your interest in LightGBM, and for all the effort you put into this very clear write-up!
To add more detail, I believe @jmoralez is referencing the fix from #4185 (which addressed #4026). Prior to
It looks like the feature
print(dask.array.nanmin(X_train.partitions[0], 0).compute())
print(dask.array.nanmax(X_train.partitions[0], 0).compute())
# [-0.40703553 -0.22746199]
# [ 1.45464902 21.05879944]
print(dask.array.nanmin(X_train.partitions[1], 0).compute())
print(dask.array.nanmax(X_train.partitions[1], 0).compute())
# [ 1.45535379 -0.22753906]
# [25.88824717 25.7793674 ]
I tested tonight and it does seem that
Testing code
Note that the sample code in the issue description cannot be copied and run directly, since it does not contain code for defining
I also changed the
import numpy as np
import pandas as pd
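# Import the submodules explicitly; a bare `import dask` or `import sklearn`
# does not guarantee that these submodules are loaded.
import dask.array
import dask.dataframe
import lightgbm as lgb
import sklearn.metrics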
from distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=2)
client = Client(cluster)
# Prepare data
train_data = pd.read_csv("https://github.com/microsoft/LightGBM/files/7369639/train_data.csv")
train_data_dask = dask.dataframe.from_pandas(train_data, npartitions=2)
X_train = train_data_dask[["x0", "x1"]].to_dask_array(lengths=True)
y_train = train_data_dask["y"].to_dask_array(lengths=True)
w_train = train_data_dask["weight"].to_dask_array(lengths=True)
# Model training and in-sample prediction
model = lgb.DaskLGBMRegressor(
client=client,
max_depth=8,
learning_rate=0.01,
tree_learner="data",
n_estimators=500,
)
model.fit(X_train, y_train, sample_weight=w_train)
y_pred = model.predict(X=X_train)
# Measure the result using r_squared
y_local = y_train.compute()
w_local = w_train.compute()
y_pred_local = y_pred.compute()
r_squared = sklearn.metrics.r2_score(y_local, y_pred_local, sample_weight=w_local)
print(f"r_squared: {r_squared}")
# --- check if any features have non-overlapping distributions across partitions --- #
print(dask.array.nanmin(X_train.partitions[0], 0).compute())
print(dask.array.nanmax(X_train.partitions[0], 0).compute())
# [-0.40703553 -0.22746199]
# [ 1.45464902 21.05879944]
print(dask.array.nanmin(X_train.partitions[1], 0).compute())
print(dask.array.nanmax(X_train.partitions[1], 0).compute())
# [ 1.45535379 -0.22753906]
# [25.88824717 25.7793674 ]
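The min/max comparison above can be turned into a small reusable check. The sketch below is a NumPy-only illustration and not part of LightGBM or Dask: the helper name `non_overlapping_features` is hypothetical. It takes per-partition minima and maxima like those printed above and reports the columns whose value ranges do not overlap between two partitions.

```python
import numpy as np


def non_overlapping_features(min_a, max_a, min_b, max_b):
    """Return indices of features whose [min, max] ranges are disjoint.

    Hypothetical helper for illustration only. Each argument is a 1-D
    sequence with one entry per feature, for partition A or partition B.
    """
    min_a, max_a = np.asarray(min_a), np.asarray(max_a)
    min_b, max_b = np.asarray(min_b), np.asarray(max_b)
    # Two closed intervals are disjoint when one ends before the other begins.
    disjoint = (max_a < min_b) | (max_b < min_a)
    return np.flatnonzero(disjoint)


# Per-partition minima and maxima copied from the output above.
part0_min = [-0.40703553, -0.22746199]
part0_max = [1.45464902, 21.05879944]
part1_min = [1.45535379, -0.22753906]
part1_max = [25.88824717, 25.7793674]

disjoint_idx = non_overlapping_features(part0_min, part0_max, part1_min, part1_max)
print(disjoint_idx)
```

With the numbers above, only column 0 (`x0` in the testing code) is flagged: its values in partition 0 end just below where its values in partition 1 begin, while the ranges of column 1 overlap.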
Given this investigation, I feel confident closing this issue. Either of the following approaches should be sufficient to avoid this bug:
|
Thanks for the explanation. That makes a lot of sense. |
Description
We are experimenting with distributed learning using lightgbm + dask, and noticed that the predictions could be obviously wrong (crazy numbers). For the reproducible example below, the r_squared of the in-sample prediction using lightgbm + dask is -1.5e+56. As a comparison, training on a local machine without distributed learning generates a reasonable prediction with r_squared ~ 0.01.
Reproducible example
Given that a dask client has already been created,
Output:
Environment info
dask version: 2021.05.1
lightgbm version: 3.2.1
Dataset
train_data.csv
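For context on the r_squared values quoted in the description: the coefficient of determination is 1 minus the ratio of the residual sum of squares to the total sum of squares, so it has no lower bound, and predictions that are off by many orders of magnitude give a hugely negative score. The sketch below is a NumPy illustration with made-up numbers, not the issue's data. `weighted_r2` is a hypothetical helper that follows the usual weighted definition.

```python
import numpy as np


def weighted_r2(y_true, y_pred, sample_weight):
    """Weighted R^2 = 1 - SS_res / SS_tot (hypothetical helper for illustration)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    w = np.asarray(sample_weight, dtype=float)
    # Weighted mean of the targets, used as the baseline prediction.
    y_mean = np.average(y_true, weights=w)
    ss_res = np.sum(w * (y_true - y_pred) ** 2)
    ss_tot = np.sum(w * (y_true - y_mean) ** 2)
    return 1.0 - ss_res / ss_tot


# Made-up targets and uniform weights, chosen only to illustrate the scale.
y = np.array([0.0, 1.0, 2.0, 3.0])
w = np.ones_like(y)

r2_perfect = weighted_r2(y, y, w)              # exact predictions
r2_mean = weighted_r2(y, np.full(4, 1.5), w)   # always predicting the mean
r2_wild = weighted_r2(y, y + 1e28, w)          # predictions off by a huge amount

print(r2_perfect, r2_mean, r2_wild)
```

A score near 0 means the model does about as well as predicting the weighted mean, while a score on the order of -1e+56 means the predictions themselves are astronomically large compared with the targets.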