
Refactor split evaluator using device-wide Scan primitive #7197

Open

wants to merge 29 commits into master

Conversation

hcho3
Collaborator

@hcho3 hcho3 commented Aug 28, 2021

Instead of performing a per-block scan, perform a segmented scan using the entire GPU. We simulate the segmented scan with a custom ("fancy") scan operator (see the sketch below). This refactor will help balance work across thread blocks and raise GPU utilization.

Original idea by @trivialfis

WIP, as some tests are failing.
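
For readers unfamiliar with the trick, here is a minimal self-contained sketch of the idea; SegScanPair and SegScanOp are illustrative names of mine, not identifiers from this PR. A single device-wide thrust::inclusive_scan behaves like a segmented scan if its operator restarts the running sum at every segment boundary:

#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <vector>

struct SegScanPair {
  int segment;  // segment (e.g. node) this element belongs to
  float value;  // running partial sum within the segment
};

struct SegScanOp {
  __host__ __device__ SegScanPair operator()(SegScanPair lhs, SegScanPair rhs) const {
    if (lhs.segment != rhs.segment) {
      return rhs;  // crossed a segment boundary: restart the sum
    }
    return {rhs.segment, lhs.value + rhs.value};
  }
};

int main() {
  // Two segments: {1, 2, 3} and {4, 5}
  std::vector<SegScanPair> h = {{0, 1.f}, {0, 2.f}, {0, 3.f}, {1, 4.f}, {1, 5.f}};
  thrust::device_vector<SegScanPair> d(h.begin(), h.end());
  thrust::inclusive_scan(d.begin(), d.end(), d.begin(), SegScanOp{});
  // values are now {1, 3, 6, 4, 9}: the scan restarted at the second segment
  return 0;
}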

@trivialfis trivialfis mentioned this pull request Sep 9, 2021
@hcho3
Collaborator Author

hcho3 commented Sep 14, 2021

Can someone help me fix MultiClassesSerializationTest.GpuHist? All the other tests in gtest are passing.

@hcho3 hcho3 marked this pull request as ready for review September 14, 2021 04:19
@hcho3 hcho3 changed the title from [WIP] Refactor split evaluator to Refactor split evaluator Sep 14, 2021
@trivialfis
Member

Let me take a look later today.

@hcho3
Collaborator Author

hcho3 commented Sep 14, 2021

Failing test from the JVM package:

- XGBoostRegressor should make correct predictions after upstream random sort *** FAILED ***
  34 was not less than 3.4000000000000004 (XGBoostRegressorSuite.scala:69)

@trivialfis
Member

That's a test asserting that the result closely matches the exact tree method. I removed the constraint in #7214 to get the test to pass, but haven't investigated the cause yet.

@trivialfis
Member

Reproduced the failing gtest. I think there are still some issues with this PR; I compared the results between master and this PR with this simple script:

import numpy as np
import xgboost

np.random.seed(1994)

kRows = 2048
kCols = 100

X = np.random.randn(kRows, kCols)
y = np.random.randn(kRows)

dtrain = xgboost.DMatrix(X, y)

bst = xgboost.train(
    {"tree_method": "gpu_hist"},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "validation_0")],
)

master

[0] validation_0-rmse:1.01426
[1] validation_0-rmse:0.93217
[2] validation_0-rmse:0.89363
[3] validation_0-rmse:0.85061
[4] validation_0-rmse:0.81075
[5] validation_0-rmse:0.77508
[6] validation_0-rmse:0.73942
[7] validation_0-rmse:0.71745
[8] validation_0-rmse:0.69094
[9] validation_0-rmse:0.67619

rework-evaluation

[0] validation_0-rmse:1.01243
[1] validation_0-rmse:0.93273
[2] validation_0-rmse:0.90907
[3] validation_0-rmse:0.89448
[4] validation_0-rmse:0.94705
[5] validation_0-rmse:0.95018
[6] validation_0-rmse:0.93193
[7] validation_0-rmse:0.92086
[8] validation_0-rmse:0.96244
[9] validation_0-rmse:0.99594

@hcho3
Collaborator Author

hcho3 commented Sep 15, 2021

@trivialfis Thanks for the simple Python script. I managed to fix the bug that was causing it to fail. (The bug had to do with setting the threshold fvalue incorrectly in the backward pass.) The fix also resolved the failing test in the JVM package.
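
For context, a hedged illustration (not the PR's actual code; names and signature are mine) of the forward/backward asymmetry described above. The backward pass must take the left edge of the current bin as the split threshold, not the right edge that the forward pass uses:

// Illustrative only. `cut` holds the sorted right edges of the histogram
// bins for one feature; `min_fvalue` is the feature's minimum value.
__device__ float SplitThreshold(float const* cut, uint32_t bin_idx,
                                float min_fvalue, bool forward) {
  if (forward) {
    // forward scan accumulates bins [begin, bin_idx];
    // the threshold is the right edge of bin_idx
    return cut[bin_idx];
  }
  // backward scan accumulates bins [bin_idx, end); the threshold is the
  // LEFT edge of bin_idx -- using cut[bin_idx] here shifts every split by one bin
  return bin_idx == 0 ? min_fvalue : cut[bin_idx - 1];
}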

Unfortunately, the gtest MultiClassesSerializationTest.GpuHist is still failing. I am looking into it now.

@hcho3 hcho3 changed the title from Refactor split evaluator to Refactor split evaluator using device-wide Scan primitive Sep 15, 2021
@hcho3
Collaborator Author

hcho3 commented Sep 15, 2021

EDIT. The bug is fixed in the latest commit.

Another small counterexample, where categorical data do not yield the same result as their one-hot-encoded version:

import pandas as pd
import numpy as np
import xgboost as xgb

def make_categorical(
    n_samples: int, n_features: int, n_categories: int, onehot: bool
):

    rng = np.random.RandomState(1994)

    pd_dict = {}
    for i in range(n_features + 1):
        c = rng.randint(low=0, high=n_categories, size=n_samples)
        pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

    df = pd.DataFrame(pd_dict)
    label = df.iloc[:, 0]
    df = df.iloc[:, 1:]
    for i in range(0, n_features):
        label += df.iloc[:, i]
    label += 1

    df = df.astype("category")
    categories = np.arange(0, n_categories)
    for col in df.columns:
        df[col] = df[col].cat.set_categories(categories)

    if onehot:
        return pd.get_dummies(df), label
    return df, label

rows = 10
cols = 3
rounds = 2
cats = 4

onehot, label = make_categorical(rows, cols, cats, True)
cat, _ = make_categorical(rows, cols, cats, False)

parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}

m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
    parameters,
    m,
    num_boost_round=rounds,
    evals=[(m, "Train")]
)

m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
    parameters,
    m,
    num_boost_round=rounds,
    evals=[(m, "Train")]
)

Output:

[0]	Train-rmse:5.24927
[1]	Train-rmse:4.13486
[0]	Train-rmse:5.24927
[1]	Train-rmse:3.97093

The correct output is as follows, using the code from the master branch:

[0]	Train-rmse:5.24927
[1]	Train-rmse:4.13486
[0]	Train-rmse:5.24927
[1]	Train-rmse:4.13486

@hcho3
Collaborator Author

hcho3 commented Sep 15, 2021

I managed to fix all failing tests in gtest. Fingers crossed.

@hcho3
Collaborator Author

hcho3 commented Sep 15, 2021

EDIT. The bug is fixed in the latest commit.

Nope, another failing test with the custom objective:

_______________________________________________ TestGPUBasicModels.test_custom_objective ________________________________________________

self = <test_gpu_basic_models.TestGPUBasicModels object at 0x7f428b0efc90>

    def test_custom_objective(self):
>       self.cpu_test_bm.run_custom_objective("gpu_hist")

tests/python-gpu/test_gpu_basic_models.py:41: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <test_basic_models.TestModels object at 0x7f42cd0e7cd0>, tree_method = 'gpu_hist'

    def run_custom_objective(self, tree_method=None):
        param = {
            'max_depth': 2,
            'eta': 1,
            'objective': 'reg:logistic',
            "tree_method": tree_method
        }
        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
        num_round = 10
    
        def logregobj(preds, dtrain):
            labels = dtrain.get_label()
            preds = 1.0 / (1.0 + np.exp(-preds))
            grad = preds - labels
            hess = preds * (1.0 - preds)
            return grad, hess
    
        def evalerror(preds, dtrain):
            labels = dtrain.get_label()
            preds = 1.0 / (1.0 + np.exp(-preds))
            return 'error', float(sum(labels != (preds > 0.5))) / len(labels)
    
        # test custom_objective in training
        bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj,
                        feval=evalerror)
        assert isinstance(bst, xgb.core.Booster)
        preds = bst.predict(dtest)
        labels = dtest.get_label()
        err = sum(1 for i in range(len(preds))
                  if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
>       assert err < 0.1
E       assert 0.25822470515207946 < 0.1

tests/python/test_basic_models.py:171: AssertionError

Comment on lines +35 to +41
if (kv.first == "default_left" || kv.first == "split_conditions") {
  auto const& l_arr = get<Array const>(l_obj.at(kv.first));
  auto const& r_arr = get<Array const>(r_obj.at(kv.first));
  ASSERT_EQ(l_arr.size(), r_arr.size());
} else {
  CompareJSON(l_obj.at(kv.first), r_obj.at(kv.first));
}
Collaborator Author


This change allows two models to differ in default_left and split_conditions but still requires them to produce identical loss_chg. The relaxation is necessary because the test produced two split candidates with identical loss_chg and findex but different default_left and split_conditions. This scenario occurred because there were no missing values for that particular feature.

@hcho3
Collaborator Author

hcho3 commented Sep 23, 2021

All gtests are now fixed.

One last test to fix, this time from the multi-GPU side:

____________________ TestDistributedGPU.test_early_stopping ____________________
    @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
    def test_early_stopping(self, local_cuda_cluster: LocalCUDACluster) -> None:
        from sklearn.datasets import load_breast_cancer
        with Client(local_cuda_cluster) as client:
            X, y = load_breast_cancer(return_X_y=True)
            X, y = da.from_array(X), da.from_array(y)
    
            m = dxgb.DaskDMatrix(client, X, y)
    
            valid = dxgb.DaskDMatrix(client, X, y)
            early_stopping_rounds = 5
            booster = dxgb.train(client, {'objective': 'binary:logistic',
                                          'eval_metric': 'error',
                                          'tree_method': 'gpu_hist'}, m,
                                 evals=[(valid, 'Valid')],
                                 num_boost_round=1000,
                                 early_stopping_rounds=early_stopping_rounds)[
                                     'booster']
            assert hasattr(booster, 'best_score')
            dump = booster.get_dump(dump_format='json')
            assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
    
            valid_X = X
            valid_y = y
            cls = dxgb.DaskXGBClassifier(objective='binary:logistic',
                                         tree_method='gpu_hist',
                                         n_estimators=100)
            cls.client = client
            cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
                    eval_set=[(valid_X, valid_y)])
            booster = cls.get_booster()
            dump = booster.get_dump(dump_format='json')
>           assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
E           assert 1 == 6
E             +1
E             -6

tests/python-gpu/test_gpu_with_dask.py:388: AssertionError

@hcho3
Collaborator Author

hcho3 commented Sep 23, 2021

I believe I have fixed all the tests that I've seen fail so far. Hopefully there are no more issues. Fingers crossed.

@hcho3
Collaborator Author

hcho3 commented Sep 23, 2021

All tests have passed. cc @trivialfis @RAMitchell

dh::device_vector<DeviceSplitCandidate> out_reduce(3);
GPUTrainingParam param = left.param;
thrust::reduce_by_key(
Collaborator Author

Note. We may want to use cub::DeviceSegmentedReduce instead.
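
A hedged sketch of that alternative, assuming a contiguous candidate buffer with offsets delimiting the three segments; the function name, the offset layout, and the loss_chg comparison are illustrative, not this PR's code:

#include <cub/device/device_segmented_reduce.cuh>

struct MaxLossChg {
  __device__ DeviceSplitCandidate operator()(DeviceSplitCandidate const& l,
                                             DeviceSplitCandidate const& r) const {
    return l.loss_chg >= r.loss_chg ? l : r;  // keep the better candidate
  }
};

// d_offsets = {0, n_left, n_left + n_right, n_total} delimits the 3 segments
void SegmentedBest(DeviceSplitCandidate const* d_in, DeviceSplitCandidate* d_out,
                   int const* d_offsets, int num_segments) {
  std::size_t temp_bytes = 0;
  // first call only computes the required temporary storage size
  cub::DeviceSegmentedReduce::Reduce(nullptr, temp_bytes, d_in, d_out,
                                     num_segments, d_offsets, d_offsets + 1,
                                     MaxLossChg{}, DeviceSplitCandidate{});
  dh::TemporaryArray<char> temp(temp_bytes);
  cub::DeviceSegmentedReduce::Reduce(temp.data().get(), temp_bytes, d_in, d_out,
                                     num_segments, d_offsets, d_offsets + 1,
                                     MaxLossChg{}, DeviceSplitCandidate{});
}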

@codecov-commenter

codecov-commenter commented Sep 23, 2021

Codecov Report

Merging #7197 (d8b7d75) into master (475fd1a) will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master    #7197   +/-   ##
=======================================
  Coverage   82.76%   82.76%           
=======================================
  Files          13       13           
  Lines        4061     4061           
=======================================
  Hits         3361     3361           
  Misses        700      700           


@RAMitchell RAMitchell left a comment
Member

Some minor improvements could be made around unreadable if statements, but mostly looks good.

As discussed, we will need performance analysis before merging.

TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right) {
auto l_n_features = left.feature_segments.empty() ? 0 : left.feature_segments.size() - 1;
Member

In what situations do we get empty feature segments?

@hcho3 hcho3 Sep 24, 2021
Collaborator Author

When the left parameter is set to {}. This case is handled in a short if block labeled "trivial split".

l.Update(r, param);
return l;
});
if (right.gradient_histogram.empty()) {
Member

Does this edge case occur?

Member

Root node.

});
/**
* Perform segmented reduce to find the best split candidate per node.
* Note that there will be THREE segments:
Member

Why are there three segments? It looks like you're only using two output values.

@hcho3 hcho3 Sep 24, 2021
Collaborator Author

This is because we first generate candidate splits with a forward pass over the gradient histogram, in order of ascending node_idx (0, then 1), and then generate candidate splits with a backward pass over the histogram, in order of descending node_idx. The resulting list of candidate splits therefore consists of three segments:

[segment for left child node] [segment for right child node] [segment for left child node]

So after performing the segmented reduce, a small kernel launch folds the 3-element out_reduce into the 2-element out_splits.
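
A toy illustration of why the keyed reduce yields exactly three outputs here (the values and the plus-reduction are made up; the real code reduces split candidates):

#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <vector>

int main() {
  // node_idx per candidate: the forward pass emits [0 0 | 1 1], the
  // backward pass emits [1 1 | 0 0]; concatenated: 0 0 1 1 1 1 0 0
  std::vector<int> h_keys = {0, 0, 1, 1, 1, 1, 0, 0};
  std::vector<float> h_vals = {1, 2, 3, 4, 5, 6, 7, 8};
  thrust::device_vector<int> keys(h_keys.begin(), h_keys.end());
  thrust::device_vector<float> vals(h_vals.begin(), h_vals.end());
  thrust::device_vector<int> out_keys(3);
  thrust::device_vector<float> out_vals(3);  // three segments, not two
  thrust::reduce_by_key(keys.begin(), keys.end(), vals.begin(),
                        out_keys.begin(), out_vals.begin());
  // out_keys = {0, 1, 0}; out_vals = {3, 18, 15}
  // a tiny follow-up kernel then folds segments 0 and 2 into the left child
  return 0;
}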

ret.forward = entry.forward;
ret.gpair = split_input.gradient_histogram[entry.hist_idx];
ret.parent_sum = split_input.parent_sum;
if (((entry.node_idx == 0) &&
Member

This section of code is hard to understand and could be improved.

Collaborator Author

I will create a small utility function that checks whether feature_set contains ret.findex.
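
Something like the following sketch (the name and signature are mine; it assumes feature_set is a sorted span of feature indices, as sampled feature sets are):

#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>

template <typename FeatureSpan>
__device__ bool HasFeature(FeatureSpan const& feature_set, int32_t findex) {
  // feature_set holds sorted feature indices; binary-search for membership
  return thrust::binary_search(thrust::seq, feature_set.data(),
                               feature_set.data() + feature_set.size(),
                               findex);
}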

// Segmented Scan
return rhs;
}
if (((lhs.node_idx == 0) &&
Member

Same here.

EvaluateSplits(out_split, evaluator, input, {});
ReduceElem<GradientSumT>
__device__ ReduceValueOp<GradientSumT>::operator() (ScanElem<GradientSumT> e) {
ReduceElem<GradientSumT> ret;
Member

Can readability be improved here?

@hcho3 hcho3 Sep 24, 2021
Collaborator Author

Can you be more specific? All this function does is initialize the members of ReduceElem. Do you find the computation of left_sum and right_sum confusing? Any suggestion is appreciated.

@trivialfis trivialfis left a comment
Member

Initial review. Will run some benchmarks.

auto l_n_features = left.feature_segments.empty() ? 0 : left.feature_segments.size() - 1;
auto r_n_features = right.feature_segments.empty() ? 0 : right.feature_segments.size() - 1;
if (!(r_n_features == 0 || l_n_features == r_n_features)) {
throw std::runtime_error("Invariant violated");
Member

LOG(FATAL) ?

* [segment for left child node] [segment for right child node] [segment for left child node]
* This is due to how we perform forward and backward passes over the gradient histogram.
*/
dh::device_vector<DeviceSplitCandidate> out_reduce(3);
Member

Maybe use dh::TemporaryArray, which uses a caching allocator.
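
For reference, this is why a caching allocator matters for tiny per-node temporaries like out_reduce(3): the pool reuses freed blocks instead of hitting cudaMalloc/cudaFree on every call. A sketch using cub::CachingDeviceAllocator (dh::TemporaryArray wraps the same idea internally; the details here are assumptions):

#include <cub/util_allocator.cuh>

cub::CachingDeviceAllocator g_alloc;  // pool reused across calls

void Example() {
  void* scratch = nullptr;
  // served from the pool after the first call, avoiding a cudaMalloc
  // for every tree node evaluated
  g_alloc.DeviceAllocate(&scratch, 3 * sizeof(float));
  // ... use scratch as the 3-element temporary ...
  g_alloc.DeviceFree(scratch);  // returns the block to the pool
}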

return l;
});
if (right.gradient_histogram.empty()) {
dh::LaunchN(1, [out_reduce = dh::ToSpan(out_reduce), out_splits]__device__(std::size_t) {
Member

Maybe we can fuse this kernel with the one here:

dh::LaunchN(2, [=] __device__(size_t idx) {

Collaborator Author

Maybe. However, it would make the EvaluateSplits() function harder to reason about.

Member

From recent profiling results, small kernels have a significant impact when we grow deeper trees. But I agree that fusing them now would be premature optimization. Let's see the results first. ;-)


// size <= i < size * 2
return thrust::make_tuple(size * 2 - 1 - i, false);
} else {
return thrust::make_tuple(static_cast<uint32_t>(0), false);
Member

Does this happen?

Collaborator Author

It should not; I put it there just in case.

Member

You can put a SPAN_CHECK here as a runtime assert.
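
A sketch of that suggestion applied to the index mapping above (the function name and the forward branch are assumptions; only the backward and fallback branches appear in the diff):

#include <thrust/tuple.h>

__device__ thrust::tuple<uint32_t, bool> MapScanIndex(std::size_t i, std::size_t size) {
  if (i < size) {
    return thrust::make_tuple(static_cast<uint32_t>(i), true);  // forward pass
  } else if (i < size * 2) {
    // size <= i < size * 2: mirror the index for the backward pass
    return thrust::make_tuple(static_cast<uint32_t>(size * 2 - 1 - i), false);
  } else {
    SPAN_CHECK(false);  // unreachable: i must be in [0, size * 2)
    return thrust::make_tuple(static_cast<uint32_t>(0), false);
  }
}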


template <typename GradientSumT>
struct ScanElem {
uint32_t node_idx;
Member

bst_node_t.

GradientSumT partial_sum{0.0, 0.0};
GradientSumT parent_sum{0.0, 0.0};
float loss_chg{std::numeric_limits<float>::lowest()};
int32_t findex{-1};
Member

bst_node_t.
