From 1205ea923b8bdeb90fb62a65424e2218762554ac Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 18 Jul 2024 12:23:01 +0900 Subject: [PATCH] enable callback func test with xgboost>=1.6.0 --- .../test_meta_schedule_cost_model.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/python/meta_schedule/test_meta_schedule_cost_model.py b/tests/python/meta_schedule/test_meta_schedule_cost_model.py index 0e1b2f64216b..dadedcf601aa 100644 --- a/tests/python/meta_schedule/test_meta_schedule_cost_model.py +++ b/tests/python/meta_schedule/test_meta_schedule_cost_model.py @@ -257,17 +257,6 @@ def test_meta_schedule_xgb_model_reupdate(): model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) -def xgb_version_check(): - - # pylint: disable=import-outside-toplevel - import xgboost as xgb - from packaging import version - - # pylint: enable=import-outside-toplevel - return version.parse(xgb.__version__) >= version.parse("1.6.0") - - -@unittest.skipIf(xgb_version_check(), "test not supported for xgboost version after 1.6.0") def test_meta_schedule_xgb_model_callback_as_function(): # pylint: disable=import-outside-toplevel from itertools import chain as itertools_chain @@ -330,14 +319,12 @@ def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignor num_boost_round=10000, obj=obj, callbacks=[ - partial( - _get_custom_call_back( - early_stopping_rounds=model.early_stopping_rounds, - verbose_eval=model.verbose_eval, - fevals=[rmse, avg_peak_score], - evals=[(d_train.dmatrix, "tr")], - cvfolds=None, - ) + _get_custom_call_back( + early_stopping_rounds=model.early_stopping_rounds, + verbose_eval=model.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(d_train.dmatrix, "tr")], + cvfolds=None, ) ], )