
[dask] Distributed training sometimes produces very high leaf values #4026

Closed
jose-moralez opened this issue Feb 27, 2021 · 37 comments

@jose-moralez

Description

Hi. I'm trying to use LightGBM for time series forecasting and it works fine with local models. However, when using Dask the predictions are sometimes huge (my time series are in the range [0, 160] and the predictions are around 10^38). I've included a reproducible example (I'm sorry it's not that minimal, but I couldn't reproduce it with regular datasets). Running the example below, sometimes the values are just kind of big, but other times they really explode. With bigger datasets I've seen that the first tree is pretty much the same between local and Dask, but by the second tree the values in Dask become increasingly large (so it may have something to do with the gradients). I haven't been able to track down the issue and was hoping someone could point me in the right direction.

Reproducible example

from itertools import chain
from math import ceil, log10

import dask.dataframe as dd
import lightgbm as lgb
import numpy as np
import pandas as pd
from dask.distributed import Client

def generate_daily_series(n_series: int, 
                          min_length: int = 50, 
                          max_length: int = 500):
    series_lengths = np.random.randint(min_length, max_length+1, n_series)
    total_length = series_lengths.sum()
    n_digits = ceil(log10(n_series))
    
    dates = pd.date_range('2000-01-01', periods=max_length, freq='D').values
    uids = [[f'id_{i:0{n_digits}}'] * serie_length for i, serie_length in enumerate(series_lengths)]
    ds = [dates[-serie_length:] for serie_length in series_lengths]
    y = np.arange(7)[np.arange(total_length) % 7] + np.random.rand(total_length) * 0.5
    series = pd.DataFrame({
        'unique_id': list(chain.from_iterable(uids)),
        'ds': list(chain.from_iterable(ds)),
        'y': y
    })
    series = series.set_index('unique_id')
    return series

def generate_features(df, lags, expanding, rolling,
                      ewm_alpha, date_features):
    df = df.copy()
    for lag in lags:
        df[f'lag_{lag}'] = df.y.shift(lag)
    grouped_lag1 = df.groupby('unique_id')['lag_1']
    for op in expanding:
        df[f'expanding_{op}'] = grouped_lag1.transform(
            lambda y: y.expanding().agg(op)
        )
    for op in rolling:
        df[f'rolling_{op}'] = grouped_lag1.transform(
            lambda y: y.rolling(7).agg(op)
        )
    dates = df.ds.dt
    for feature in date_features:
        df[feature] = getattr(dates, feature)
    df = df.dropna()
    return df

client = Client(n_workers=2, 
                threads_per_worker=2, 
                memory_limit='4 GiB')

# generate data
config = dict(
    lags=[1, 7],
    expanding=['min', 'max', 'mean'],
    rolling=['mean'],
    ewm_alpha=0.1,
    date_features=['dayofweek', 'month', 'year', 'day']  
)
data = generate_daily_series(20, min_length=300, max_length=700)
train = generate_features(data, **config)
X, y = train.drop(columns=['ds', 'y']), train.y

# local regressor
local_reg = lgb.LGBMRegressor().fit(X, y)
local_df = local_reg.booster_.trees_to_dataframe()

# dask regressor
dtrain = dd.from_pandas(train, npartitions=2)
dX, dy = dtrain.drop(columns=['ds', 'y']), dtrain.y
dask_reg = lgb.dask.DaskLGBMRegressor().fit(dX, dy)
dask_df = dask_reg.booster_.trees_to_dataframe()

# Generate plot
ax = dask_df.groupby('tree_index')['value'].mean().plot(label='dask')
local_df.groupby('tree_index')['value'].mean().plot(ax=ax, label='local')
ax.legend();

Environment info

LightGBM version or commit hash: 6356e65

Command(s) you used to install LightGBM

git clone --recursive https://github.com/microsoft/LightGBM.git
cd LightGBM/python-package
python setup.py install

Here are some example plots.

Sometimes the values are just "kind of big":
[plot: mean leaf value per tree_index, dask vs local]

But other times they really explode:
[plot: mean leaf value per tree_index, dask vs local]

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

FWIW this doesn't happen with dask-lightgbm. I'll try to find where the problem is.

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

This actually seems to be related to lightgbm itself. I just pasted the original dask-lightgbm code into python-package/lightgbm/dask.py and found the same issue. The version of lightgbm I have in my dask-lightgbm environment is 2.3.1; were there any major changes to distributed training after that? I only found #3110 in the release notes, but I just built from the commit before that merge and the issue is still there.

@guolinke
Collaborator

guolinke commented Mar 3, 2021

How much data (rows) did you use?
If the data is too small, distributed learning will perform poorly, because the feature bucketing is done with the local data only.

You can try to construct a dataset on a single node first, via Dataset.construct(), then save it to binary format with Dataset.save_binary(). Then use the binary file for the distributed learning.

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

The example I posted generates 20 time series of between 300 and 700 rows each, so around 10,000 rows in total on average. However, the issue persists with more data: I just changed the 20 to 1,000, trained on 500,000 rows and reproduced the issue. I'm using two machines; I also noticed that when adding more machines the values in the tree nodes go even higher (even with big datasets).

I'm trying to track down exactly where this originated, because building from 2.3.1 (8364fc3) works fine. However, with some of the later changes (da91c61, d0bec9e) I get this error: lightgbm.basic.LightGBMError: Check failed: best_split_info.left_count > 0 (in case that helps).

@guolinke
Collaborator

guolinke commented Mar 3, 2021

@jmoralez I think it is better to fix (use the same) dataset first, constructing it from all rows.
Then we will know which part causes the problem.

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

Do you mean using the CLI?

@guolinke
Collaborator

guolinke commented Mar 3, 2021

You can try to construct a dataset on a single node first, via Dataset.construct(), then save it to binary format with Dataset.save_binary(). Then use the binary file for the distributed learning.

These are all Python interfaces.

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

I've constructed the dataset and saved it to binary format. I'm just confused about

Then use the binary file for the distributed learning.

@guolinke
Collaborator

guolinke commented Mar 3, 2021

After you save the binary file, you can copy it to the nodes for distributed learning.
Then you can use the binary file path, just as you would a normal CSV file, to construct the dataset for distributed learning.

This will force the distributed learning to use the same dataset as single-machine mode.
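
Roughly, that workflow looks like the following with the Python API. This is only a sketch: the file path is a placeholder, and the distributed networking parameters (machines, num_machines, local_listen_port) are omitted.

import lightgbm as lgb

# on a single node: build the bin mappers from ALL rows, then save the result
full_ds = lgb.Dataset(X, label=y)      # X, y: the full training data
full_ds.construct()                    # bin boundaries are computed here
full_ds.save_binary('train.bin')       # write the pre-constructed dataset

# on each machine: load the same binary file, so every worker uses identical bins
shared_ds = lgb.Dataset('train.bin')
params = {'objective': 'regression', 'tree_learner': 'data'}
booster = lgb.train(params, shared_ds)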

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

Thank you. I just trained using the CLI on my local machine with two "machines" using the binary file, then loaded the booster in Python to look at the leaf values, and they seem normal. Do you have any suggestions for the dask case?

@guolinke
Collaborator

guolinke commented Mar 3, 2021

Okay, so the problem indeed happens in Dataset construction.
I am not familiar with the dask interface, but I think the Python lgb.Dataset should support construction from a binary file as well.

To ultimately fix this problem, we should add an "accurate" mode for feature bucketing. This will significantly slow down the pre-processing stage before training, due to heavy communication.

also cc @shiyu1994 to implement the "accurate" mode

The algorithm of the "accurate" mode is roughly the following (a toy sketch follows below):

  1. Each node samples part of its rows.
  2. Use all-gather to collect the sampled rows from all nodes.
  3. Each node finds the bin mappers for its assigned features, as is currently done.
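
A toy, single-process sketch of those three steps (the all-gather is just a concatenation here, and find_bin_bounds is a stand-in for the real bin-mapper construction):

import numpy as np

def find_bin_bounds(values, max_bin=255):
    # placeholder for LightGBM's bin-mapper construction: equal-frequency upper bounds
    quantiles = np.linspace(0, 1, max_bin + 1)[1:]
    return np.quantile(values, quantiles)

n_nodes = 2
rng = np.random.RandomState(0)
# each node holds a shifted distribution of the same feature
local_feature = [rng.normal(loc=node, size=10_000) for node in range(n_nodes)]

# 1. each node samples part of its rows
samples = [rng.choice(col, size=1_000, replace=False) for col in local_feature]

# 2. all-gather: every node ends up with the samples from all nodes
gathered = np.concatenate(samples)

# 3. each node builds the bin mappers for its assigned features from the global sample
bounds = find_bin_bounds(gathered)
print(bounds[:5])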

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

I should point out that this doesn't happen with all datasets. For example, sklearn.datasets.load_boston, which has only about 500 rows, performs well. I actually struggled to produce an example; I'm not sure what's particularly special about these kinds of datasets (time series) that impacts the dataset creation. Like I said, it used to work in 2.3.1, but I'm not sure why the changes since then would have a greater impact here.

@guolinke
Collaborator

guolinke commented Mar 3, 2021

@jmoralez
Can you also try force_col_wise=true with dask (without the pre-constructed dataset) and see what happens?
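
For reference, passing that parameter through the Dask estimator would look roughly like this, reusing dX and dy from the reproducible example above (extra keyword arguments are forwarded to the booster):

dask_reg = lgb.DaskLGBMRegressor(force_col_wise=True).fit(dX, dy)
print(dask_reg.booster_.trees_to_dataframe()['value'].abs().max())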

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

I get high values as well

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

I think it has to do with low-variance features. In this example I create, for instance, expanding_min, which has values like 0.1, 0.1, 0.1, 0.1, 0.2, ..., so it doesn't vary much, and somehow that breaks the model. However, removing the low-variance features and training only on the lags and the rolling mean, which have more variance, gives good results.

@guolinke
Collaborator

guolinke commented Mar 3, 2021

@jmoralez did you shuffle the data by rows?

@jmoralez
Collaborator

jmoralez commented Mar 3, 2021

I did try that, but it doesn't help.

@jose-moralez
Author

Hi. I've been investigating further and I have a new example. The data is y = x**2 + id + U(0, 1), where x ~ N(0, 1) and id is an integer between 0 and 19. This example produces differences starting from the second tree.

The data looks like this:
[plot of the generated data]

from itertools import chain

import dask.dataframe as dd
import lightgbm as lgb
import numpy as np
import pandas as pd
from dask.distributed import Client

def create_data(n_series: int = 20, min_length: int = 300, max_length: int = 700, seed: int = 0):
    rng = np.random.RandomState(seed)
    series_lengths = rng.randint(min_length, max_length+1, n_series)
    total_size = series_lengths.sum()
    ids = list(chain.from_iterable([[i] * length for i, length in enumerate(series_lengths)]))    
    x = rng.standard_normal(total_size)
    y = x**2 + rng.rand(total_size)
    df = pd.DataFrame({
        'id': ids,
        'x': x,
        'y': y + ids
    })
    return df

def get_max_abs_leaf_per_tree_index(tree_df):
    tree_df['abs_value'] = tree_df['value'].abs()
    leaves_df = tree_df[lambda x: x.left_child.isnull() & x.right_child.isnull()]
    return leaves_df.groupby('tree_index')['abs_value'].max()    

# set up two "machines"
client = Client(n_workers=2, threads_per_worker=2, memory_limit='4 GiB')

# generate data
df = create_data()
df = df.sample(frac=1)  # shuffle
ddf = dd.from_pandas(df, npartitions=2)
X, y = ddf.drop('y', 1), ddf.y
Xc, yc = X.compute(), y.compute()

# train models and get tree_dfs
dask_model = lgb.DaskLGBMRegressor().fit(X, y)
dask_tree_df = dask_model.booster_.trees_to_dataframe()
dask_max_values_per_tree = get_max_abs_leaf_per_tree_index(dask_tree_df)

local_model = lgb.LGBMRegressor().fit(Xc, yc)
local_tree_df = local_model.booster_.trees_to_dataframe()
local_max_values_per_tree = get_max_abs_leaf_per_tree_index(local_tree_df)

# analyze fit results
values_per_tree = pd.DataFrame({
    'dask': dask_max_values_per_tree,
    'local': local_max_values_per_tree,
})
values_per_tree['dask_minus_local'] = values_per_tree['dask'] - values_per_tree['local']
values_per_tree.head(5).applymap('{:.1f}'.format)

The maximum leaf values per tree look like the following:
[table: max abs leaf value per tree_index for dask and local, and their difference]
This is strange, because if the dataset creation were the problem, I'd expect the first tree not to be so similar.

Looking at the nodes in the second tree I see some strange behaviour:

use_cols = ['node_depth', 'node_index', 'parent_index', 'split_gain', 'value', 'count']

tree_index = 1

l = []
for name, tree_df in {'dask': dask_tree_df, 'local': local_tree_df}.items():
    tree = tree_df.loc[lambda x: x.tree_index == tree_index, use_cols].head(20)
    tree.columns = pd.MultiIndex.from_product([[name], tree.columns])
    l.append(tree)

pd.concat(l, axis=1)

[table: node stats for tree 1, dask vs local, side by side]

So for the root node the gain is about the same, but the dask version then finds huge gains (10**6 in node 1-S14 in this example) and has leaves with values 152 and -306 (the data is in the range [0, 30] and the first tree doesn't have any negative leaf values).

I'm happy to investigate this further if someone can point me in the right direction.

@jose-moralez
Author

jose-moralez commented Mar 10, 2021

I extended the table above to include split_feature and threshold and I realized that in the second tree for the exact same split the distributed version sends 20 samples the wrong way (2331 vs 2351).

[table: tree 1 with split_feature and threshold columns, dask vs local]

Edit
Here's the first tree for reference, where the same splits are performed but the counts are ok.
[table: tree 0 for reference]

@jose-moralez
Author

jose-moralez commented Mar 12, 2021

I just realized that even though I shuffle df, since dask partitions on the index, the data isn't actually shuffled in the dask dataframe (i.e. the first partition has ids 0-9 and the second one 10-19). We have to do df = df.sample(frac=1).reset_index(drop=True), and doing this removes the issue (see the sketch below). However, for my use case I split the data by id for preprocessing, so shuffling the entire dataset would be very expensive, and if you only shuffle the data within each partition the problem persists.
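
A small sketch of the difference, reusing create_data from the example above (random_state is just for reproducibility):

import dask.dataframe as dd

df = create_data()

# dd.from_pandas partitions on the sorted index, so a plain df.sample(frac=1)
# gets un-shuffled again and each partition still holds a contiguous block of ids.
# Dropping the old index keeps the shuffled order when partitioning.
shuffled = df.sample(frac=1, random_state=0).reset_index(drop=True)
ddf = dd.from_pandas(shuffled, npartitions=2)

# each partition now contains a mix of ids instead of ids 0-9 / 10-19
for i in range(ddf.npartitions):
    part = ddf.get_partition(i).compute()
    print(i, part['id'].nunique(), 'distinct ids')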

I also found something new: I added the weight column, and from the second tree onwards weight and count start to differ in the distributed case.
[table: tree nodes showing weight and count diverging from tree 1 onward in the dask model]

@jmoralez
Collaborator

I just tried using tree_learner=voting_parallel and got similar results. I think the problem is in the samples being sent the wrong way, which produces wrong node values and unrealistic gains. I've looked at the source code, but I have no idea where the issue could be.

@guolinke
Collaborator

cc @shiyu1994 for possible bugs.

@shiyu1994
Collaborator

@jameslamb I'm debugging this. Could you tell me how to make the dask regressor print the warning messages from the C++ code? It seems that neither silent=False nor verbose=2 works.

@jameslamb
Collaborator

@shiyu1994 sure, I think I can help.

This answer uses a modified version of the code from https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dask/regression.py (just added silent=False and verbose=2).

  1. Start a Python REPL:

python

  2. Run the following in it:
import dask.array as da
from distributed import Client, LocalCluster
from sklearn.datasets import make_regression
import lightgbm as lgb

cluster = LocalCluster(n_workers=2)
client = Client(cluster)
X, y = make_regression(n_samples=1000, n_features=50)
dX = da.from_array(X, chunks=(100, 50))
dy = da.from_array(y, chunks=(100,))
dask_model = lgb.DaskLGBMRegressor(n_estimators=10, silent=False, verbose=2)
dask_model.fit(dX, dy)

When I did this, I saw the following printed in my Python console after calling fit()

Finding random open ports for workers
[LightGBM] [Warning] num_threads is set=4, n_jobs=-1 will be ignored. Current value: num_threads=4
[LightGBM] [Warning] num_threads is set=4, n_jobs=-1 will be ignored. Current value: num_threads=4
[LightGBM] [Debug] Dataset::GetMultiBinFromAllFeatures: sparse rate 0.000000
[LightGBM] [Debug] init for col-wise cost 0.000004 seconds, init for row-wise cost 0.000503 seconds
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001011 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 12750
[LightGBM] [Info] Number of data points in the train set: 1000, number of used features: 50
[LightGBM] [Info] Start training from score 6.255258
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 6
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9
DaskLGBMRegressor(n_estimators=10, num_threads=4, silent=False, time_out=120,
                  tree_learner='data', verbose=2)

Are there other specific log messages that you think are missing?

One other thing you can try

You might be able to find more diagnostic information on the Dask dashboard, especially in the "Info" tab that can take you to additional worker logs.

In that same Python REPL, you could get the address for the dashboard by running

print(client.dashboard_link)

Paste that into a browser and you'll be able to see the Dask diagnostic dashboard.

[screenshot: Dask diagnostic dashboard]

But when I did that, I only saw Dask logs and nothing else from LightGBM.

distributed.worker - INFO - Start worker at: tcp://127.0.0.1:33003
distributed.worker - INFO - Listening to: tcp://127.0.0.1:33003
distributed.worker - INFO - dashboard at: 127.0.0.1:34517
distributed.worker - INFO - Waiting to connect to: tcp://127.0.0.1:45311
distributed.worker - INFO - -------------------------------------------------
distributed.worker - INFO - Threads: 4
distributed.worker - INFO - Memory: 16.77 GB
distributed.worker - INFO - Local Directory: /home/jlamb/repos/open-source/LightGBM/dask-worker-space/worker-vlbru721
distributed.worker - INFO - -------------------------------------------------
distributed.worker - INFO - Registered to: tcp://127.0.0.1:45311
distributed.worker - INFO - -------------------------------------------------
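
If digging through the dashboard is inconvenient, I believe the same worker logs can also be pulled straight into the REPL with the generic distributed client API (nothing LightGBM-specific):

from pprint import pprint

# dict keyed by worker address; each value holds that worker's recent log records
pprint(client.get_worker_logs())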

@jameslamb
Collaborator

ok now I'm weirded out by the changing max_depth, looking at that more closely! That seems unexpected, right? I might open a separate issue about it later, don't want to distract too much from this discussion.

[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 6
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 7
[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 9

@shiyu1994
Collaborator

@jameslamb Thanks!

I do find a bug in data parallel training.

When the values of the same feature on different machines differ a lot, or even have no overlap, a globally optimal split can result in leaves with 0 local data on a machine. For example, in the synthetic example by @jmoralez, the id values on machine 0 are in [0, 10] while machine 1 has [11, ...]. A split on the feature id then always results in an empty leaf on one of the machines.

Even in such a case, we are expected to provide a stable (though perhaps not good enough) result. But the problem is that when a leaf has data count 0, we currently skip the histogram construction entirely, as follows:

inline void ConstructHistograms(
    const std::vector<int8_t>& is_feature_used,
    const data_size_t* data_indices, data_size_t num_data,
    const score_t* gradients, const score_t* hessians,
    score_t* ordered_gradients, score_t* ordered_hessians,
    TrainingShareStates* share_state, hist_t* hist_data) const {
  if (num_data <= 0) {
    return;
  }

This does not cause any problem in the single-machine scenario. However, with data-distributed training we need to send the contents of the histogram buffer to the other machines. With num_data <= 0, we exit ConstructHistograms directly, without clearing the buffer:

LightGBM/src/io/dataset.cpp

Lines 1192 to 1193 in 6ad3e6e

std::memset(reinterpret_cast<void*>(data_ptr), 0,
            num_bin * kHistEntrySize);

So, when training the second tree, the locally empty leaf (the leaf is not actually empty when considering the global training data) on a machine will send the histogram contents left over from the first iteration to the other machines, which results in an incorrect global histogram.

That's why the problem always appears from the second iteration onward.
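
To make the mechanism concrete, here is a toy single-process simulation. This is not the actual LightGBM code: the per-machine buffers and the reduce step are simplified to numpy arrays and a sum.

import numpy as np

NUM_BINS = 4

def construct_local_histogram(buffer, bin_ids, gradients):
    # mimics the early return in ConstructHistograms: when the leaf has no local
    # rows, the buffer keeps whatever it contained from the previous iteration
    if len(bin_ids) == 0:
        return
    buffer[:] = 0
    np.add.at(buffer, bin_ids, gradients)

# machine 1 holds none of the rows of the leaf considered in iteration 2
buffers = [np.zeros(NUM_BINS), np.zeros(NUM_BINS)]

# iteration 1: both machines have rows for their leaves, buffers get filled
construct_local_histogram(buffers[0], np.array([0, 1, 2]), np.array([1.0, 1.0, 1.0]))
construct_local_histogram(buffers[1], np.array([2, 3]), np.array([5.0, 5.0]))

# iteration 2: the current leaf is locally empty on machine 1,
# but its stale buffer from iteration 1 still enters the reduce
construct_local_histogram(buffers[0], np.array([0, 1]), np.array([2.0, 2.0]))
construct_local_histogram(buffers[1], np.array([], dtype=int), np.array([]))

global_hist = buffers[0] + buffers[1]  # stand-in for the Reduce-Scatter step
print(global_hist)  # bins 2 and 3 still carry gradients from iteration 1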

Let me open a PR to fix this.

@StrikerRUS
Collaborator

@jameslamb

ok now I'm weirded out by the changing max_depth, looking at that more closely!

I believe this is just a logging bug (bad wording):

Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth);

@jameslamb
Collaborator

oooo ok I'll submit a PR to fix that

@imatiach-msft
Contributor

@jmoralez are you seeing output like this from lightgbm:

[LightGBM] [Debug] Trained a tree with leaves = 31 and max_depth = 13
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Debug] Trained a tree with leaves = 1 and max_depth = 1
[LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
[LightGBM] [Info] Finished linking network in 11.926274 seconds

I'm seeing this for the regressor, where the last tree seems to output very high values.

The image you had above, where the split gain is NaN, seems to match what I am seeing. But unlike the mmlspark code, the dask code continues to call update and ignores the value it returns:

https://github.com/microsoft/LightGBM/blob/master/python-package/lightgbm/dask.py#L90

I created an issue here:
#4178
to understand how to deal with this case, when LightGBM outputs "Stopped training because there are no more leaves that meet the split requirements"

@shiyu1994
Collaborator

@imatiach-msft Not sure whether the two are related. Which distributed training strategy are you using?

@imatiach-msft
Contributor

@shiyu1994 it's using data_parallel training strategy

@jose-moralez
Author

Fixed by #4185

@jose-moralez
Author

Hi @shiyu1994. Is it possible that this could still be present when splitting categorical features? I think I'm experiencing the same issue when I use them.

@imatiach-msft
Contributor

@jose-moralez are you sure you are using the latest lightgbm on master? This issue was only recently fixed. If so, I would suggest creating a new issue. A similar issue a customer had in mmlspark was resolved with the new change.

@jose-moralez
Author

You're right @imatiach-msft, I had installed from source and checked that the issue was solved but I recently rebuilt my env and installed 3.2.1 again. It was my bad, sorry.

@imatiach-msft
Contributor

imatiach-msft commented Apr 27, 2021

@jose-moralez ah ok, great to hear it's resolved then - if you still see issues though I would either just reopen this one or create a new one, it's difficult to keep track of closed issues (at least for me when I work on other open source projects as a maintainer)

@github-actions

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 23, 2023