From 8edc397ce164bbecda559ff6ba045f0608a5a3e1 Mon Sep 17 00:00:00 2001
From: jiabeizhao
Date: Wed, 20 Jul 2022 15:15:03 +0800
Subject: [PATCH 1/2] upgrade callback function type to xgboost
 callback.TrainingCallback

---
 .../auto_scheduler/cost_model/xgb_model.py | 203 ++++++++++--------
 1 file changed, 115 insertions(+), 88 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index a4e39b906149..ad18d5dfa8ca 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -22,12 +22,20 @@
 from collections import defaultdict
 
 import numpy as np
-
+from typing import Dict
 from tvm.autotvm.tuner.metric import max_curve
 from .cost_model import PythonBasedModel
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass
+
+
 xgb = None
 
 logger = logging.getLogger("auto_scheduler")
@@ -198,7 +206,7 @@ def update(self, inputs, results):
             num_boost_round=10000,
             obj=pack_sum_square_error,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=50,
                     metric="tr-p-rmse",
                     fevals=[
@@ -539,125 +547,144 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
-            else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
 
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
 
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
 
+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
+
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
+
+    def after_iteration(self, model, epoch, _evals_log):
+        """Run after each iteration. Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
+            else:
+                self.state["best_score"] = float("inf")
+
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
 
         res_dict = {}
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True
 
-    return callback
+        return False

From ede9413b698ff5dbe094de99349a9018c76d78cd Mon Sep 17 00:00:00 2001
From: jiabeizhao
Date: Fri, 22 Jul 2022 17:30:54 +0800
Subject: [PATCH 2/2] lint

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index ad18d5dfa8ca..82430cab6a64 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -20,9 +20,8 @@
 import multiprocessing
 import logging
 from collections import defaultdict
-
-import numpy as np
 from typing import Dict
+import numpy as np
 from tvm.autotvm.tuner.metric import max_curve
 from .cost_model import PythonBasedModel
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader
@@ -587,7 +586,7 @@ def __init__(
         self.skip_every = skip_every
         self.state = {}
 
-    def after_iteration(self, model, epoch, _evals_log):
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
         """Run after each iteration. Return True when training should stop."""
         # pylint:disable = import-outside-toplevel
         try: