-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cover approx tree method for categorical data tests. (#7569)
* Add tree to df tests. * Add plotting tests. * Add histogram tests.
- Loading branch information
1 parent
465dc63
commit d6ea5cc
Showing
4 changed files
with
55 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,14 @@ | ||
import sys | ||
import pytest | ||
import xgboost as xgb | ||
|
||
sys.path.append("tests/python") | ||
import testing as tm | ||
from test_parse_tree import TestTreesToDataFrame | ||
|
||
|
||
def test_tree_to_df_categorical(): | ||
X, y = tm.make_categorical(100, 10, 31, False) | ||
Xy = xgb.DMatrix(X, y, enable_categorical=True) | ||
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10) | ||
df = booster.trees_to_dataframe() | ||
for _, x in df.iterrows(): | ||
if x["Feature"] != "Leaf": | ||
assert len(x["Category"]) == 1 | ||
cputest = TestTreesToDataFrame() | ||
cputest.run_tree_to_df_categorical("gpu_hist") | ||
|
||
|
||
def test_split_value_histograms(): | ||
X, y = tm.make_categorical(1000, 10, 13, False) | ||
reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True) | ||
reg.fit(X, y) | ||
|
||
with pytest.raises(ValueError, match="doesn't"): | ||
reg.get_booster().get_split_value_histogram("3", bins=5) | ||
cputest = TestTreesToDataFrame() | ||
cputest.run_split_value_histograms("gpu_hist") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,17 @@ | ||
import sys | ||
import xgboost as xgb | ||
import pytest | ||
import json | ||
|
||
sys.path.append("tests/python") | ||
import testing as tm | ||
|
||
try: | ||
import matplotlib | ||
|
||
matplotlib.use("Agg") | ||
from matplotlib.axes import Axes | ||
from graphviz import Source | ||
except ImportError: | ||
pass | ||
import test_plotting as tp | ||
|
||
|
||
pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz())) | ||
|
||
|
||
class TestPlotting: | ||
cputest = tp.TestPlotting() | ||
|
||
@pytest.mark.skipif(**tm.no_pandas()) | ||
def test_categorical(self): | ||
X, y = tm.make_categorical(1000, 31, 19, onehot=False) | ||
reg = xgb.XGBRegressor( | ||
enable_categorical=True, n_estimators=10, tree_method="gpu_hist" | ||
) | ||
reg.fit(X, y) | ||
trees = reg.get_booster().get_dump(dump_format="json") | ||
for tree in trees: | ||
j_tree = json.loads(tree) | ||
assert "leaf" in j_tree.keys() or isinstance( | ||
j_tree["split_condition"], list | ||
) | ||
|
||
graph = xgb.to_graphviz(reg, num_trees=len(j_tree) - 1) | ||
assert isinstance(graph, Source) | ||
ax = xgb.plot_tree(reg, num_trees=len(j_tree) - 1) | ||
assert isinstance(ax, Axes) | ||
self.cputest.run_categorical("gpu_hist") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters