diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 6b6b7a2dc1ed..aaee58fc94c8 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -755,7 +755,12 @@ def _fmt_metric(value, show_stdv=True): raise ValueError("wrong metric value", value) import xgboost as xgb - from xgboost import rabit # type: ignore + + # make it compatible with xgboost<1.7 + try: + from xgboost import rabit as collective # type: ignore + except ImportError: + from xgboost import collective # type: ignore try: from xgboost.training import aggcv # type: ignore @@ -841,7 +846,7 @@ def _fmt_metric(value, show_stdv=True): elif epoch - best_iteration >= self.early_stopping_rounds: best_msg = self.state["best_msg"] - if self.verbose_eval and rabit.get_rank() == 0: + if self.verbose_eval and collective.get_rank() == 0: logger.debug("XGB stopped. Best iteration: %s ", best_msg) # instead of raising EarlyStopException, returning True to end the training return True