diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 34ad027d1e91..b7586e537f6e 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -142,9 +142,7 @@ def _train_internal(params, dtrain, ) else: raise ValueError(f'Unknown booster: {booster}') - num_groups = int(config['learner']['learner_model_param']['num_class']) - num_groups = 1 if num_groups == 0 else num_groups - bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups + bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree # Copy to serialise and unserialise booster to reset state and free # training memory @@ -184,9 +182,10 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, If there's more than one metric in the **eval_metric** parameter given in **params**, the last metric will be used for early stopping. If early stopping occurs, the model will have three additional fields: - ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. - (Use ``bst.best_ntree_limit`` to get the correct value if - ``num_parallel_tree`` and/or ``num_class`` appears in the parameters) + ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. Use + ``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree`` and/or + ``num_class`` appears in the parameters. ``best_ntree_limit`` is the result of + ``num_parallel_tree * best_iteration``. evals_result: dict This dictionary stores the evaluation results of all the items in watchlist. diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index a44eea916222..ef719bd47044 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -33,9 +33,15 @@ def run_predict_leaf(predictor): y = rng.randint(low=0, high=classes, size=rows) m = xgb.DMatrix(X, y) booster = xgb.train( - {'num_parallel_tree': num_parallel_tree, 'num_class': classes, - 'predictor': predictor, 'tree_method': 'hist'}, m, - num_boost_round=num_boost_round) + { + "num_parallel_tree": num_parallel_tree, + "num_class": classes, + "predictor": predictor, + "tree_method": "hist", + }, + m, + num_boost_round=num_boost_round, + ) empty = xgb.DMatrix(np.ones(shape=(0, cols))) empty_leaf = booster.predict(empty, pred_leaf=True) @@ -52,12 +58,19 @@ def run_predict_leaf(predictor): end = classes * num_parallel_tree * (j + 1) layer = row[start: end] for c in range(classes): - tree_group = layer[c * num_parallel_tree: - (c+1) * num_parallel_tree] + tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree] assert tree_group.shape[0] == num_parallel_tree # no subsampling so tree in same forest should output same # leaf. assert np.all(tree_group == tree_group[0]) + + ntree_limit = 2 + sliced = booster.predict( + m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit + ) + first = sliced[0, ...] + + assert first.shape[0] == classes * num_parallel_tree * ntree_limit return leaf diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py index 2c4e577d2316..9990ca61b05a 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -123,13 +123,13 @@ def run_training_continuation(self, xgb_params_01, xgb_params_02, gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7) assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5 + gbdt_05.best_iteration + 1) * self.num_parallel_tree gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05) assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5 + gbdt_05.best_iteration + 1) * self.num_parallel_tree res1 = gbdt_05.predict(dtrain_5class) res2 = gbdt_05.predict(dtrain_5class, diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index d4d121f12283..d2c90fb71b91 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -92,7 +92,7 @@ def train(booster, forest): ) if forest: - assert cls.best_ntree_limit == rounds * forest * cls.n_classes_ + assert cls.best_ntree_limit == rounds * forest else: assert cls.best_ntree_limit == 0