Handle missing values in one hot splits. #7934

Merged 4 commits on May 24, 2022

72 changes: 63 additions & 9 deletions src/tree/hist/evaluate_splits.h
@@ -45,14 +45,72 @@ class HistEvaluator {
// then - there are no missing values
// else - there are missing values
bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
- if (e.GetGrad() == snode.stats.GetGrad() &&
-     e.GetHess() == snode.stats.GetHess()) {
+ if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
return false;
} else {
return true;
}
}

bool IsValid(GradStats const &left, GradStats const &right) const {
return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
}

/**
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
* create a pseudo-category for missing value but here we just do a complete scan
* to avoid making specialized histogram bin.
*/
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();

bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
bst_bin_t n_bins = iend - ibegin;

GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto feature_sum = GradStats{
std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
GradStats missing;
auto const &parent = snode_[nidx];
missing.SetSubstract(parent.stats, feature_sum);

for (bst_bin_t i = ibegin; i != iend; i += 1) {
auto split_pt = cut_val[i];

// missing on left (treat missing as other categories)
right_sum = GradStats{hist[i]};
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_left_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
}

// missing on right (treat missing as chosen category)
left_sum.SetSubstract(left_sum, missing);
right_sum.Add(missing);
if (IsValid(left_sum, right_sum)) {
auto missing_right_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
}
}

p_best->Update(best);
}

// Enumerate/Scan the split values of specific feature
// Returns the sum of gradients corresponding to the data points that contains
// a non-missing value for the particular feature fid.
@@ -102,9 +160,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
- // not-chosen categories go to left
- right_sum = GradStats{hist[i]};
- left_sum.SetSubstract(parent.stats, right_sum);
+ std::terminate(); // unreachable
break;
}
case kPart: {
@@ -151,7 +207,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
- split_pt = cut_val[i];
+ std::terminate(); // unreachable
break;
}
case kPart: {
@@ -188,7 +244,6 @@ class HistEvaluator {
// Normal, accumulated to left
return left_sum;
case kOneHot:
- // Doesn't matter, not accumulating.
return {};
case kPart:
// Accumulated to right due to chosen cats go to right.
@@ -242,8 +297,7 @@ class HistEvaluator {
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
- EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
- EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+ EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
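For intuition, here is a small NumPy sketch (not part of the PR) of the scan EnumerateOneHot performs: every category bin is scored twice, once with the missing-value rows kept on the left together with the not-chosen categories, and once with them moved to the right next to the chosen category. The helper names (gain, enumerate_one_hot) and the simplified gain formula G^2 / (H + lambda) are illustrative assumptions rather than XGBoost's exact objective; the real code also applies the min_child_weight validity check shown above.

import numpy as np

def gain(g, h, lam=1.0):
    # simplified split gain: G^2 / (H + lambda)
    return g * g / (h + lam)

def enumerate_one_hot(grad_hist, hess_hist, parent_g, parent_h, lam=1.0):
    # grad_hist / hess_hist hold per-category sums built from non-missing rows only,
    # so parent totals minus the histogram sums give the statistics of the missing rows.
    miss_g = parent_g - grad_hist.sum()
    miss_h = parent_h - hess_hist.sum()
    parent_gain = gain(parent_g, parent_h, lam)
    best = (-np.inf, -1, False)  # (gain change, category, missing joins chosen category?)
    for c in range(len(grad_hist)):
        # chosen category goes right; everything else, including missing, stays left
        rg, rh = grad_hist[c], hess_hist[c]
        lg, lh = parent_g - rg, parent_h - rh
        chg = gain(lg, lh, lam) + gain(rg, rh, lam) - parent_gain
        if chg > best[0]:
            best = (chg, c, False)
        # missing rows follow the chosen category to the right instead
        rg, rh = rg + miss_g, rh + miss_h
        lg, lh = lg - miss_g, lh - miss_h
        chg = gain(lg, lh, lam) + gain(rg, rh, lam) - parent_gain
        if chg > best[0]:
            best = (chg, c, True)
    return best

rng = np.random.default_rng(0)
g, h = rng.normal(size=8), rng.uniform(0.5, 1.5, size=8)
print(enumerate_one_hot(g, h, parent_g=g.sum() + 0.3, parent_h=h.sum() + 0.7))
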
32 changes: 28 additions & 4 deletions tests/python/test_updaters.py
@@ -214,17 +214,19 @@ def test_max_cat(self, tree_method) -> None:
self.run_max_cat(tree_method)

def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1

onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)

by_etl_results = {}
by_builtin_results = {}

predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
+ parameters = {"tree_method": tree_method, "predictor": predictor}
# Use one-hot exclusively
- parameters = {
-     "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
- }
+ parameters["max_cat_to_onehot"] = USE_ONEHOT

m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
@@ -257,7 +259,8 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
parameters["max_cat_to_onehot"] = 1
# switch to partition-based splits
parameters["max_cat_to_onehot"] = USE_PART
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
@@ -284,6 +287,27 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping

# test with missing values
cat, label = tm.make_categorical(
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
evals_result = {}
# Test with onehot splits
parameters["max_cat_to_onehot"] = USE_ONEHOT
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)

rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])

@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
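The USE_ONEHOT / USE_PART constants above exercise both categorical split strategies: max_cat_to_onehot is a real XGBoost parameter, and a feature is split with one-hot encoding whenever its number of categories is below that threshold, otherwise with partition-based grouping. A minimal sketch of toggling the two (train_with_strategy and its arguments are invented for illustration, assuming X is a DataFrame with category-dtype columns):

import numpy as np
import xgboost as xgb

def train_with_strategy(X, y, use_onehot: bool, tree_method: str = "hist"):
    # int32 max makes every categorical feature use one-hot splits,
    # while 1 forces partition-based (grouping) splits for all of them.
    threshold = np.iinfo(np.int32).max if use_onehot else 1
    params = {"tree_method": tree_method, "max_cat_to_onehot": threshold}
    Xy = xgb.DMatrix(X, y, enable_categorical=True)
    return xgb.train(params, Xy, num_boost_round=16)
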
15 changes: 14 additions & 1 deletion tests/python/testing.py
@@ -302,7 +302,7 @@ def get_mq2008(dpath):

@memory.cache
def make_categorical(
- n_samples: int, n_features: int, n_categories: int, onehot: bool
+ n_samples: int, n_features: int, n_categories: int, onehot: bool, sparsity=0.0,
):
import pandas as pd

@@ -325,6 +325,13 @@ def make_categorical(
for col in df.columns:
df[col] = df[col].cat.set_categories(categories)

if sparsity > 0.0:
for i in range(n_features):
index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
df.iloc[index, i] = np.NaN
assert df.iloc[:, i].isnull().values.any()
assert n_categories == np.unique(df.dtypes[i].categories).size

if onehot:
return pd.get_dummies(df), label
return df, label
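
A quick sanity check of the new sparsity argument (illustrative, not part of the PR): every column should come back with some NaN entries, at most the requested fraction since the row indices are drawn with replacement, while keeping its category dtype. The import assumes tests/python is on sys.path, as in the test suite.

import testing as tm

X, y = tm.make_categorical(n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5)
for col in X.columns:
    assert X[col].isna().any()              # missing values were injected
    assert str(X[col].dtype) == "category"  # dtype is preserved
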
@@ -538,6 +545,12 @@ def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
return np.sum(r)


def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
err = y_score - y_true
rmse = np.sqrt(np.dot(err, err) / y_score.size)
return rmse


def softmax(x):
e = np.exp(x)
return e / np.sum(e)
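For completeness, root_mean_square above is the usual RMSE, sqrt(mean((y_score - y_true)^2)); a quick check against a direct NumPy formulation (the function is restated locally so the snippet runs on its own):

import numpy as np

def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
    err = y_score - y_true
    return np.sqrt(np.dot(err, err) / y_score.size)

rng = np.random.default_rng(1)
y_true, y_score = rng.normal(size=128), rng.normal(size=128)
np.testing.assert_allclose(root_mean_square(y_true, y_score),
                           np.sqrt(np.mean((y_score - y_true) ** 2)))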