From 959bca5855b1b20391bb36ec1aef1ab8e8d6ccf3 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 19 Jul 2022 16:25:55 -0700 Subject: [PATCH 01/16] update the custom callback function of xgboost --- .../tvm/meta_schedule/cost_model/xgb_model.py | 184 ++++++++++++++++-- 1 file changed, 172 insertions(+), 12 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 8de034758b4b..6954047f5d3f 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -22,7 +22,7 @@ import tempfile from collections import OrderedDict from itertools import chain as itertools_chain -from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple import numpy as np # type: ignore @@ -35,6 +35,14 @@ from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve +try: + from xgboost.callback import TrainingCallback +except ImportError: + + class TrainingCallback: + pass + + if TYPE_CHECKING: import xgboost as xgb # type: ignore @@ -573,22 +581,19 @@ def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument return self.d_train.average_peak_score(ys_pred, self.average_peak_n) + xgb_custom_callback = XGBoostCustomCallback( + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(self.d_train.dmatrix, "tr")], + cvfolds=None, + ) self.booster = xgb.train( self.config.to_dict(), self.d_train.dmatrix, num_boost_round=10000, obj=obj, - callbacks=[ - custom_callback( - early_stopping_rounds=self.early_stopping_rounds, - verbose_eval=self.verbose_eval, - fevals=[ - rmse, - avg_peak_score, - ], - evals=[(self.d_train.dmatrix, "tr")], - ) - ], + callbacks=[xgb_custom_callback], ) del self.d_train @@ -763,3 +768,158 @@ def callback(env: "xgb.core.CallbackEnv"): raise EarlyStopException(best_iteration) return callback + + +class XGBoostCallback(TrainingCallback): + """Base class for XGBoost callbacks.""" + + def __call__(self, env: "xgb.core.CallbackEnv"): + """Compatibility with xgboost<1.3""" + return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) + + def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + raise NotImplementedError + + +class XGBoostCustomCallback(XGBoostCallback): + """Custom callback class for xgboost to support multiple custom evaluation functions""" + + def __init__( + self, + early_stopping_rounds: int, + verbose_eval: int, + fevals: List[Callable], + evals: List[Tuple["xgb.DMatrix", str]], + focused_metric: str = "tr-p-rmse", + cvfolds: Sequence["xgb.training.CVPack"] = None, + ): + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.fevals = fevals + self.evals = evals + self.state: Dict[str, Any] = {} + self.focused_metric = focused_metric + self.sort_key = make_metric_sorter(focused_metric=focused_metric) + self.cvfolds = cvfolds + if cvfolds is not None: + self.aggregated_cv = None + + def init(self, model: "xgb.Booster"): + booster: "xgb.Booster" = model + self.state["best_iteration"] = 0 + self.state["best_score"] = float("inf") + if booster is None: + assert self.cvfolds is not None + return + if booster.attr("best_score") is not None: + self.state["best_score"] = float(booster.attr("best_score")) + self.state["best_iteration"] = int(booster.attr("best_iteration")) + self.state["best_msg"] = booster.attr("best_msg") + else: + booster.set_attr(best_iteration=str(self.state["best_iteration"])) + booster.set_attr(best_score=str(self.state["best_score"])) + + def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + try: + from xgboost.callback import _fmt_metric # type: ignore + except ImportError: + """Compatibility with xgboost>=1.6""" + + def _fmt_metric(value, show_stdv=True): + if len(value) == 2: + return f"{value[0]}:{value[1]:.5f}" + if len(value) == 3: + if show_stdv: + return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}" + return f"{value[0]}:{value[1]:.5f}" + raise ValueError("wrong metric value", value) + + from xgboost import rabit # type: ignore + + try: + from xgboost.training import aggcv # type: ignore + except ImportError: + from xgboost.callback import _aggcv as aggcv # type: ignore + if not self.state: + self.init(model) + booster: xgb.Booster = model + iteration: int = epoch + cvfolds: List[xgb.training.CVPack] = self.cvfolds + ##### Evaluation ##### + # `eval_result` is a list of (key, score) + eval_result: List[Tuple[str, float]] = [] + if cvfolds is None: + eval_result = list( + itertools_chain.from_iterable( + [ + (key, float(value)) + for key, value in map( + lambda x: x.split(":"), + booster.eval_set( + evals=self.evals, + iteration=iteration, + feval=feval, + ).split()[1:], + ) + ] + for feval in self.fevals + ) + ) + else: + eval_result = list( + itertools_chain.from_iterable( + [ + (key, score) + for key, score, _std in aggcv( + fold.eval( + iteration=iteration, + feval=feval, + ) + for fold in cvfolds + ) + ] + for feval in self.fevals + ) + ) + eval_result = list(eval_result) + eval_result.sort(key=self.sort_key) + + ##### Print eval result ##### + if self.verbose_eval and iteration % self.verbose_eval == 0: + info = [] + for key, score in eval_result: + if "null" not in key: + info.append(f"{key}: {score:.6f}") + logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) + + ##### Choose score and do early stopping ##### + score = None + for key, _score in eval_result: + if key == self.focused_metric: + score = _score + break + assert score is not None + + best_score = self.state["best_score"] + best_iteration = self.state["best_iteration"] + if score < best_score: + tab = "\t" # to work with f-string + msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}" + self.state["best_msg"] = msg + self.state["best_score"] = score + self.state["best_iteration"] = epoch + # save the property to attributes, so they will occur in checkpoint. + if model is not None: + model.set_attr( + best_score=str(self.state["best_score"]), + best_iteration=str(self.state["best_iteration"]), + best_msg=self.state["best_msg"], + ) + elif epoch - best_iteration >= self.early_stopping_rounds: + best_msg = self.state["best_msg"] + + if self.verbose_eval and rabit.get_rank() == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + return True # instead of raising EarlyStopException, returning True to end the training + # False to indicate training should not stop. + return False From da8a1d9edd02bc6903cc6cea032de4a581bea56e Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 19 Jul 2022 16:39:41 -0700 Subject: [PATCH 02/16] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 6954047f5d3f..43c2d90a67d0 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -22,7 +22,7 @@ import tempfile from collections import OrderedDict from itertools import chain as itertools_chain -from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple import numpy as np # type: ignore @@ -36,7 +36,7 @@ from .metric import max_curve try: - from xgboost.callback import TrainingCallback + from xgboost.callback import TrainingCallback # type: ignore except ImportError: class TrainingCallback: @@ -791,7 +791,7 @@ def __init__( fevals: List[Callable], evals: List[Tuple["xgb.DMatrix", str]], focused_metric: str = "tr-p-rmse", - cvfolds: Sequence["xgb.training.CVPack"] = None, + cvfolds: List["xgb.training.CVPack"] = None, ): self.early_stopping_rounds = early_stopping_rounds self.verbose_eval = verbose_eval From 8d7903ba1d796e3fb78fbd247a53cd9d4d13c6b7 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 19 Jul 2022 17:43:04 -0700 Subject: [PATCH 03/16] fix ci --- python/tvm/meta_schedule/cost_model/xgb_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 43c2d90a67d0..0904141ded9b 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -39,7 +39,7 @@ from xgboost.callback import TrainingCallback # type: ignore except ImportError: - class TrainingCallback: + class TrainingCallback: # type: ignore pass From 94d8958f4867210502634f307fb9025a53cdfadc Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 20 Jul 2022 23:06:29 -0700 Subject: [PATCH 04/16] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 0904141ded9b..8cec77a85736 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -774,7 +774,7 @@ class XGBoostCallback(TrainingCallback): """Base class for XGBoost callbacks.""" def __call__(self, env: "xgb.core.CallbackEnv"): - """Compatibility with xgboost<1.3""" + # Compatibility with xgboost < 1.3 return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): @@ -805,6 +805,7 @@ def __init__( self.aggregated_cv = None def init(self, model: "xgb.Booster"): + """Internal function for intialization""" booster: "xgb.Booster" = model self.state["best_iteration"] = 0 self.state["best_score"] = float("inf") @@ -820,10 +821,12 @@ def init(self, model: "xgb.Booster"): booster.set_attr(best_score=str(self.state["best_score"])) def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + """Internal function for after_iteration""" + # pylint:disable = import-outside-toplevel try: from xgboost.callback import _fmt_metric # type: ignore except ImportError: - """Compatibility with xgboost>=1.6""" + # Compatibility with xgboost >= 1.6 def _fmt_metric(value, show_stdv=True): if len(value) == 2: @@ -834,6 +837,7 @@ def _fmt_metric(value, show_stdv=True): return f"{value[0]}:{value[1]:.5f}" raise ValueError("wrong metric value", value) + import xgboost as xgb from xgboost import rabit # type: ignore try: From 09fd5892da2a451a93b149ddb08b724bc9fc140c Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 26 Jul 2022 18:35:35 -0700 Subject: [PATCH 05/16] add unit test --- .../tvm/meta_schedule/cost_model/__init__.py | 2 +- .../tvm/meta_schedule/cost_model/xgb_model.py | 124 ++---------------- .../unittest/test_meta_schedule_cost_model.py | 92 ++++++++++++- 3 files changed, 103 insertions(+), 115 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 8fc6f04ac955..47b418d5db12 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -19,4 +19,4 @@ """ from .cost_model import CostModel, PyCostModel from .random_model import RandomModel -from .xgb_model import XGBModel +from .xgb_model import XGBModel, XGBoostCustomCallback, PackSum diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 8cec77a85736..50f54f5da088 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -35,15 +35,15 @@ from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve -try: - from xgboost.callback import TrainingCallback # type: ignore -except ImportError: - class TrainingCallback: # type: ignore - pass +if TYPE_CHECKING: + try: + from xgboost.callback import TrainingCallback # type: ignore + except ImportError: + class TrainingCallback: # type: ignore + pass -if TYPE_CHECKING: import xgboost as xgb # type: ignore from ..tune_context import TuneContext @@ -674,114 +674,8 @@ def init(env: "xgb.core.CallbackEnv"): booster.set_attr(best_iteration=str(state["best_iteration"])) booster.set_attr(best_score=str(state["best_score"])) - def callback(env: "xgb.core.CallbackEnv"): - # pylint:disable = import-outside-toplevel - import xgboost as xgb - from xgboost.callback import _fmt_metric # type: ignore - from xgboost.core import EarlyStopException # type: ignore - - try: - from xgboost.training import aggcv # type: ignore - except ImportError: - from xgboost.callback import _aggcv as aggcv # type: ignore - # pylint:enable = import-outside-toplevel - - if not state: - init(env) - booster: xgb.Booster = env.model - iteration: int = env.iteration - cvfolds: List[xgb.training.CVPack] = env.cvfolds - ##### Evaluation ##### - # `eval_result` is a list of (key, score) - eval_result: List[Tuple[str, float]] = [] - if cvfolds is None: - eval_result = list( - itertools_chain.from_iterable( - [ - (key, float(value)) - for key, value in map( - lambda x: x.split(":"), - booster.eval_set( - evals=evals, - iteration=iteration, - feval=feval, - ).split()[1:], - ) - ] - for feval in fevals - ) - ) - else: - eval_result = list( - itertools_chain.from_iterable( - [ - (key, score) - for key, score, _std in aggcv( - fold.eval( - iteration=iteration, - feval=feval, - ) - for fold in cvfolds - ) - ] - for feval in fevals - ) - ) - eval_result = list(eval_result) - eval_result.sort(key=sort_key) - - ##### Print eval result ##### - if verbose_eval and iteration % verbose_eval == 0: - info = [] - for key, score in eval_result: - if "null" not in key: - info.append(f"{key}: {score:.6f}") - logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) - - ##### Choose score and do early stopping ##### - score = None - for key, _score in eval_result: - if key == focused_metric: - score = _score - break - assert score is not None - best_score = state["best_score"] - best_iteration = state["best_iteration"] - if score < best_score: - tab = "\t" # to work with f-string - msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" - state["best_msg"] = msg - state["best_score"] = score - state["best_iteration"] = env.iteration - # save the property to attributes, so they will occur in checkpoint. - if env.model is not None: - env.model.set_attr( - best_score=str(state["best_score"]), - best_iteration=str(state["best_iteration"]), - best_msg=state["best_msg"], - ) - elif env.iteration - best_iteration >= early_stopping_rounds: - best_msg = state["best_msg"] - if verbose_eval and env.rank == 0: - logger.debug("XGB stopped. Best iteration: %s ", best_msg) - raise EarlyStopException(best_iteration) - - return callback - - -class XGBoostCallback(TrainingCallback): - """Base class for XGBoost callbacks.""" - - def __call__(self, env: "xgb.core.CallbackEnv"): - # Compatibility with xgboost < 1.3 - return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) - - def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): - raise NotImplementedError - - -class XGBoostCustomCallback(XGBoostCallback): +class XGBoostCustomCallback(TrainingCallback): """Custom callback class for xgboost to support multiple custom evaluation functions""" def __init__( @@ -804,6 +698,10 @@ def __init__( if cvfolds is not None: self.aggregated_cv = None + def __call__(self, env: "xgb.core.CallbackEnv"): + # Compatibility with xgboost < 1.3 + return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) + def init(self, model: "xgb.Booster"): """Internal function for intialization""" booster: "xgb.Booster" = model diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index d1d558181324..91c84bbdb88b 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -26,7 +26,13 @@ import pytest import tvm import tvm.testing -from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.cost_model import ( + PyCostModel, + RandomModel, + XGBModel, + XGBoostCustomCallback, + PackSum, +) from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import MeasureCandidate @@ -228,5 +234,89 @@ def test_meta_schedule_xgb_model_reupdate(): model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) +def test_meta_schedule_xgb_model_callback(): + import xgboost as xgb + from itertools import chain as itertools_chain + from functools import partial + + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + with tempfile.NamedTemporaryFile() as path: + # Backup and train on new TrainingCallBack api + random_state = model.extractor.random_state # save feature extractor's random state + + model.save(path.name) + + old_booster = model.booster + xs = [ + x.numpy().astype("float32") + for x in extractor.extract_from( + TuneContext(), + [_dummy_candidate() for i in range(predict_sample_count)], + ) + ] + d_test = PackSum(xs=xs, ys=None) + pred1 = old_booster.predict(d_test.dmatrix) + + # Load and train on deprecated TrainingCallBack api + model.extractor.random_state = random_state # load feature extractor's random state + model.load(path.name) + d_train = PackSum( + xs=list(itertools_chain.from_iterable([g.features for g in model.data.values()])), + ys=np.concatenate( + [g.min_cost / g.costs for g in model.data.values()], + axis=0, + ), + ) + + def obj(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.obj_square_error(ys_pred) + + def rmse(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.rmse(ys_pred) + + def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.average_peak_score(ys_pred, model.average_peak_n) + + new_booster = xgb.train( + model.config.to_dict(), + d_train.dmatrix, + num_boost_round=10000, + obj=obj, + callbacks=[ + partial( + XGBoostCustomCallback( + early_stopping_rounds=model.early_stopping_rounds, + verbose_eval=model.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(d_train.dmatrix, "tr")], + cvfolds=None, + ) + ) + ], + ) + + xs = [ + x.numpy().astype("float32") + for x in extractor.extract_from( + TuneContext(), + [_dummy_candidate() for i in range(predict_sample_count)], + ) + ] + d_test = PackSum(xs=xs, ys=None) + pred2 = new_booster.predict(d_test.dmatrix) + + assert np.allclose(pred1, pred2, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main() From 8471a9aa3a9249ce08ed22664cf3cdc284e7a5f3 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 26 Jul 2022 20:21:12 -0700 Subject: [PATCH 06/16] remote unused code --- .../tvm/meta_schedule/cost_model/xgb_model.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 50f54f5da088..e704d26620a6 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -645,36 +645,6 @@ def average_peak_score(ys_pred: np.ndarray): return eval_result -def custom_callback( - early_stopping_rounds: int, - verbose_eval: int, - fevals: List[Callable], - evals: List[Tuple["xgb.DMatrix", str]], - focused_metric: str = "tr-p-rmse", -): - """Callback function for xgboost to support multiple custom evaluation functions""" - sort_key = make_metric_sorter(focused_metric=focused_metric) - - state: Dict[str, Any] = {} - - def init(env: "xgb.core.CallbackEnv"): - """Internal function""" - booster: "xgb.Booster" = env.model - - state["best_iteration"] = 0 - state["best_score"] = float("inf") - if booster is None: - assert env.cvfolds is not None - return - if booster.attr("best_score") is not None: - state["best_score"] = float(booster.attr("best_score")) - state["best_iteration"] = int(booster.attr("best_iteration")) - state["best_msg"] = booster.attr("best_msg") - else: - booster.set_attr(best_iteration=str(state["best_iteration"])) - booster.set_attr(best_score=str(state["best_score"])) - - class XGBoostCustomCallback(TrainingCallback): """Custom callback class for xgboost to support multiple custom evaluation functions""" From 4c563dc5ec2bb5a89d33cf02b55f44332f296543 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 26 Jul 2022 20:41:23 -0700 Subject: [PATCH 07/16] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index e704d26620a6..840bd8f4d63c 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -688,7 +688,9 @@ def init(self, model: "xgb.Booster"): booster.set_attr(best_iteration=str(self.state["best_iteration"])) booster.set_attr(best_score=str(self.state["best_score"])) - def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + def after_iteration( + self, model: "xgb.Booster", epoch: int, evals_log: Dict + ): # pylint: disable = unused-argument """Internal function for after_iteration""" # pylint:disable = import-outside-toplevel try: From af6dda4478b3630d8a96f4f02a3bcf13c5aed70b Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 11:55:45 -0700 Subject: [PATCH 08/16] add decorator --- python/tvm/meta_schedule/cost_model/xgb_model.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 840bd8f4d63c..47b4dec3d681 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -36,7 +36,8 @@ from .metric import max_curve -if TYPE_CHECKING: +def optional_xgboost_callback(XGBoostCustomCallback): + # pylint:disable = import-outside-toplevel try: from xgboost.callback import TrainingCallback # type: ignore except ImportError: @@ -44,6 +45,14 @@ class TrainingCallback: # type: ignore pass + class OptXGBoostCustomCallback(XGBoostCustomCallback, TrainingCallback): + pass + + return OptXGBoostCustomCallback + + +if TYPE_CHECKING: + import xgboost as xgb # type: ignore from ..tune_context import TuneContext @@ -645,7 +654,8 @@ def average_peak_score(ys_pred: np.ndarray): return eval_result -class XGBoostCustomCallback(TrainingCallback): +@optional_xgboost_callback +class XGBoostCustomCallback: """Custom callback class for xgboost to support multiple custom evaluation functions""" def __init__( From 56e4b012c1f52926fd0f2b470b0dad508a22d3e4 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 12:02:30 -0700 Subject: [PATCH 09/16] address comment --- python/tvm/meta_schedule/cost_model/xgb_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 47b4dec3d681..b451d14de27d 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -724,6 +724,8 @@ def _fmt_metric(value, show_stdv=True): from xgboost.training import aggcv # type: ignore except ImportError: from xgboost.callback import _aggcv as aggcv # type: ignore + + # pylint:enable = import-outside-toplevel if not self.state: self.init(model) booster: xgb.Booster = model From 046463157bb2354b80ab52b574924b1c2ed1fc52 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 13:06:51 -0700 Subject: [PATCH 10/16] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index b451d14de27d..433db0ca40d9 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -724,7 +724,7 @@ def _fmt_metric(value, show_stdv=True): from xgboost.training import aggcv # type: ignore except ImportError: from xgboost.callback import _aggcv as aggcv # type: ignore - + # pylint:enable = import-outside-toplevel if not self.state: self.init(model) From e9e17021533a1f6922998e4c5f6619299fde9966 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 13:27:08 -0700 Subject: [PATCH 11/16] address comments --- python/tvm/meta_schedule/cost_model/xgb_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 433db0ca40d9..c1e01ef9a57b 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -40,6 +40,7 @@ def optional_xgboost_callback(XGBoostCustomCallback): # pylint:disable = import-outside-toplevel try: from xgboost.callback import TrainingCallback # type: ignore + # pylint:enable = import-outside-toplevel except ImportError: class TrainingCallback: # type: ignore From c2e1c7eaf84dc99d977e6c0774dac36de7d9467a Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 13:40:22 -0700 Subject: [PATCH 12/16] fix mypy --- python/tvm/meta_schedule/cost_model/xgb_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index c1e01ef9a57b..0df785ba25ad 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -46,7 +46,7 @@ def optional_xgboost_callback(XGBoostCustomCallback): class TrainingCallback: # type: ignore pass - class OptXGBoostCustomCallback(XGBoostCustomCallback, TrainingCallback): + class OptXGBoostCustomCallback(XGBoostCustomCallback, TrainingCallback): # type: ignore pass return OptXGBoostCustomCallback From 775731baf47ac619ad5854f871208f6d7175c991 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Wed, 27 Jul 2022 14:36:18 -0700 Subject: [PATCH 13/16] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 0df785ba25ad..57ac4ba9233e 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -36,7 +36,8 @@ from .metric import max_curve -def optional_xgboost_callback(XGBoostCustomCallback): +def optional_xgboost_callback(cls): + """Decorator for importing TraningCallback from xgboost""" # pylint:disable = import-outside-toplevel try: from xgboost.callback import TrainingCallback # type: ignore @@ -46,7 +47,7 @@ def optional_xgboost_callback(XGBoostCustomCallback): class TrainingCallback: # type: ignore pass - class OptXGBoostCustomCallback(XGBoostCustomCallback, TrainingCallback): # type: ignore + class OptXGBoostCustomCallback(cls, TrainingCallback): # type: ignore pass return OptXGBoostCustomCallback From 08663ae3cd1693927046e56c26e3d8409b3a5a6f Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Fri, 23 Sep 2022 16:09:11 -0700 Subject: [PATCH 14/16] remove unused comments --- python/tvm/meta_schedule/cost_model/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 47b418d5db12..8fc6f04ac955 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -19,4 +19,4 @@ """ from .cost_model import CostModel, PyCostModel from .random_model import RandomModel -from .xgb_model import XGBModel, XGBoostCustomCallback, PackSum +from .xgb_model import XGBModel From a08be9764f96707195bee9936ead36f1d67a0b91 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Fri, 23 Sep 2022 16:40:59 -0700 Subject: [PATCH 15/16] address comments --- .../tvm/meta_schedule/cost_model/xgb_model.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 57ac4ba9233e..1171e081b90a 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -592,19 +592,20 @@ def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument return self.d_train.average_peak_score(ys_pred, self.average_peak_n) - xgb_custom_callback = XGBoostCustomCallback( - early_stopping_rounds=self.early_stopping_rounds, - verbose_eval=self.verbose_eval, - fevals=[rmse, avg_peak_score], - evals=[(self.d_train.dmatrix, "tr")], - cvfolds=None, - ) self.booster = xgb.train( self.config.to_dict(), self.d_train.dmatrix, num_boost_round=10000, obj=obj, - callbacks=[xgb_custom_callback], + callbacks=[ + XGBoostCustomCallback( + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(self.d_train.dmatrix, "tr")], + cvfolds=None, + ) + ], ) del self.d_train From bc83709d908e631cbeefd3c2b03babc1a90080b8 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 25 Sep 2022 13:18:34 -0700 Subject: [PATCH 16/16] Fix xgboost unit test import. --- tests/python/unittest/test_meta_schedule_cost_model.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index 91c84bbdb88b..94b7bce246f4 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -26,13 +26,8 @@ import pytest import tvm import tvm.testing -from tvm.meta_schedule.cost_model import ( - PyCostModel, - RandomModel, - XGBModel, - XGBoostCustomCallback, - PackSum, -) +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.cost_model.xgb_model import XGBoostCustomCallback, PackSum from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import MeasureCandidate