[Auto Scheduler] Upgrade autoscheduler xgboost callback #12144

Closed
204 changes: 115 additions & 89 deletions python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -20,14 +20,21 @@
 import multiprocessing
 import logging
 from collections import defaultdict
+from typing import Dict

 import numpy as np

 from tvm.autotvm.tuner.metric import max_curve
 from .cost_model import PythonBasedModel
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader

+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass

 xgb = None

 logger = logging.getLogger("auto_scheduler")
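
The try/except shim above lets the new class-based callbacks subclass xgboost's TrainingCallback where it exists (xgboost >= 1.3) and fall back to an empty base class on older versions. A minimal sketch of the new-style hook this enables — not part of the PR; CountingCallback and the synthetic data are illustrative only:

import numpy as np
import xgboost as xgb
from xgboost.callback import TrainingCallback  # available since xgboost 1.3

class CountingCallback(TrainingCallback):
    """Toy callback: counts boosting iterations via the new-style hook."""

    def __init__(self):
        self.iterations = 0

    def after_iteration(self, model, epoch, evals_log):
        self.iterations += 1
        return False  # returning True would request early stopping

dtrain = xgb.DMatrix(np.random.rand(32, 4), label=np.random.rand(32))
cb = CountingCallback()
xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=5, callbacks=[cb])
assert cb.iterations == 5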
@@ -198,7 +205,7 @@ def update(self, inputs, results):
             num_boost_round=10000,
             obj=pack_sum_square_error,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=50,
                     metric="tr-p-rmse",
                     fevals=[
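
The fevals list above is truncated by the diff context. For reference, each entry follows xgboost's classic feval protocol: a function mapping (predictions, DMatrix) to a (name, value) pair, which Booster.eval_set prefixes with the eval-set name ("tr" here) to produce keys like "tr-p-rmse". A hedged sketch, where example_p_rmse is a hypothetical stand-in rather than the PR's actual feval:

import numpy as np

def example_p_rmse(preds, dmatrix):
    """Toy feval: plain RMSE; eval_set reports it as '<setname>-p-rmse'."""
    labels = dmatrix.get_label()
    return "p-rmse", float(np.sqrt(np.mean((preds - labels) ** 2)))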
@@ -539,125 +546,144 @@ def feval(preds, labels):
     return feval


-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
-
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
-            else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
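
XGBoostCallback is the compatibility bridge: xgboost >= 1.3 calls after_iteration(model, epoch, evals_log) directly, while older versions invoke the callback as callback(env) with a CallbackEnv record, which __call__ unpacks and forwards to the same hook. A toy illustration of the old-style path, using the XGBoostCallback base defined above; FakeEnv and PrintEpoch are hypothetical and merely mimic the three CallbackEnv fields used here:

from collections import namedtuple

FakeEnv = namedtuple("FakeEnv", ["model", "iteration", "evaluation_result_list"])

class PrintEpoch(XGBoostCallback):  # hypothetical subclass for demonstration
    def after_iteration(self, model, epoch, evals_log):
        print("finished epoch", epoch)
        return False

# Old-style invocation, as xgboost < 1.3 would perform it:
PrintEpoch()(FakeEnv(model=None, iteration=3, evaluation_result_list=[]))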


+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
+
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
-
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
-
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        """Run after each iteration. Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
+            else:
+                self.state["best_score"] = float("inf")
+
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))

         res_dict = {}

-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False

         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]

         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)

         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))

             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")

         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None

-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]

         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True

-    return callback
+        return False
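
End to end, the class replaces the EarlyStopException mechanism (removed in newer xgboost) with after_iteration's boolean return: True stops training, and the best iteration survives in the Booster's attributes. A usage sketch under xgboost >= 1.3, with synthetic data; the feval and metric names are illustrative, chosen to match this file's "tr-" eval-set convention:

import numpy as np
import xgboost as xgb

def p_rmse(preds, dmatrix):
    """Toy feval, reported by eval_set as 'tr-p-rmse' for the 'tr' set."""
    labels = dmatrix.get_label()
    return "p-rmse", float(np.sqrt(np.mean((preds - labels) ** 2)))

x, y = np.random.rand(64, 8), np.random.rand(64)
dtrain = xgb.DMatrix(x, label=y)
bst = xgb.train(
    {"objective": "reg:squarederror"},
    dtrain,
    num_boost_round=200,
    callbacks=[
        CustomCallback(
            stopping_rounds=10,
            metric="tr-p-rmse",
            fevals=[p_rmse],
            evals=[(dtrain, "tr")],
            verbose_eval=25,
        )
    ],
)
# Best iteration and score were checkpointed as Booster attributes.
print(bst.attr("best_iteration"), bst.attr("best_score"))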