From e0816ba8b6b56ad7dad214ef96fdb5b3e7813c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lufang=20CHEN=20=E9=99=88=E6=A9=B9=E6=96=B9?= Date: Wed, 21 Jun 2023 06:32:33 +0000 Subject: [PATCH 1/4] support xgb set tree method --- python/tvm/meta_schedule/cost_model/xgb_model.py | 12 ++++++++++++ python/tvm/meta_schedule/tir_integration.py | 2 ++ python/tvm/meta_schedule/tune.py | 3 ++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index fde2f2f60529..fa3e8a1a429e 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -202,6 +202,8 @@ def average_peak_score( class XGBConfig(NamedTuple): """XGBoost model configuration + Reference: https://xgboost.readthedocs.io/en/stable/parameter.html + Parameters ---------- max_depth : int @@ -217,6 +219,8 @@ class XGBConfig(NamedTuple): nthread : Optional[int], The number of threads to use. Default is None, which means to use physical number of cores. + tree_method str : + The tree construction algorithm used in XGBoost. """ max_depth: int = 10 @@ -225,8 +229,11 @@ class XGBConfig(NamedTuple): eta: float = 0.2 seed: int = 43 nthread: Optional[int] = None + tree_method: str = "auto" def to_dict(self): + """Convert to dict""" + return { "max_depth": self.max_depth, "gamma": self.gamma, @@ -234,6 +241,7 @@ def to_dict(self): "eta": self.eta, "seed": self.seed, "nthread": self.nthread, + "tree_method": self.tree_method, } @@ -334,6 +342,7 @@ def __init__( average_peak_n: int = 32, adaptive_training: bool = True, num_tuning_cores: Optional[int] = None, + tree_method: Optional[str] = None, ): super().__init__() if not isinstance(extractor, FeatureExtractor): @@ -348,6 +357,9 @@ def __init__( else: config = config._replace(nthread=num_tuning_cores) + if tree_method is not None: + config._replace(tree_method=tree_method) + self.config = config # behavior of randomness self.num_warmup_samples = num_warmup_samples diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 5f6f82bf148b..299537b62eb6 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -58,6 +58,7 @@ def tune_tir( # pylint: disable=too-many-locals seed: Optional[int] = None, module_equality: str = "structural", special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None, + **kwargs, ) -> Database: """Tune a TIR function or an IRModule of TIR functions. @@ -154,6 +155,7 @@ def tune_tir( # pylint: disable=too-many-locals measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, + **kwargs, ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 132f446a5252..a65bd70a2481 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -41,6 +41,7 @@ def tune_tasks( measure_callbacks: MeasureCallback.CallbackListType = "default", task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", module_equality: str = "structural", + **kwargs, ) -> Database: """Tune a list of tasks. Using a task scheduler. @@ -108,7 +109,7 @@ def tune_tasks( elif not isinstance(database, Database): database = Database.create(database, module_equality=module_equality) if not isinstance(cost_model, CostModel): - cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores) + cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, **kwargs) if isinstance(measure_callbacks, MeasureCallback): measure_callbacks = [measure_callbacks] elif measure_callbacks == "default": From dca90693301062674f991454edfdd9f2aa7eeb4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lufang=20CHEN=20=E9=99=88=E6=A9=B9=E6=96=B9?= Date: Sat, 22 Jul 2023 06:27:42 +0000 Subject: [PATCH 2/4] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index fa3e8a1a429e..4afc6834c772 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -21,6 +21,8 @@ from itertools import chain as itertools_chain from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from typing_extensions import Literal + import numpy as np # type: ignore from ...contrib.tar import tar, untar @@ -219,7 +221,7 @@ class XGBConfig(NamedTuple): nthread : Optional[int], The number of threads to use. Default is None, which means to use physical number of cores. - tree_method str : + tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"] The tree construction algorithm used in XGBoost. """ @@ -229,7 +231,7 @@ class XGBConfig(NamedTuple): eta: float = 0.2 seed: int = 43 nthread: Optional[int] = None - tree_method: str = "auto" + tree_method: Literal["auto", "exact", "approx", "hist", "gpu_hist"] = "auto" def to_dict(self): """Convert to dict""" From 45755a181a60e69f975b011cc2b8473a80c90877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lufang=20CHEN=20=E9=99=88=E6=A9=B9=E6=96=B9?= Date: Sun, 23 Jul 2023 09:00:12 +0000 Subject: [PATCH 3/4] fix lint --- python/tvm/meta_schedule/cost_model/xgb_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 4afc6834c772..6b6b7a2dc1ed 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -344,7 +344,7 @@ def __init__( average_peak_n: int = 32, adaptive_training: bool = True, num_tuning_cores: Optional[int] = None, - tree_method: Optional[str] = None, + tree_method: Optional[Literal["auto", "exact", "approx", "hist", "gpu_hist"]] = None, ): super().__init__() if not isinstance(extractor, FeatureExtractor): From 2245123a457336128a8fc7578a78220f95f5fe31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lufang=20CHEN=20=E9=99=88=E6=A9=B9=E6=96=B9?= Date: Wed, 26 Jul 2023 06:18:29 +0000 Subject: [PATCH 4/4] fix --- python/tvm/meta_schedule/cost_model/cost_model.py | 9 ++++++--- python/tvm/meta_schedule/tir_integration.py | 2 -- python/tvm/meta_schedule/tune.py | 3 +-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index c0f6ea5fb9e1..541154d4cc59 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -127,9 +127,12 @@ def create( if kind == "xgb": return XGBModel(*args, **kwargs) # type: ignore - if "num_tuning_cores" in kwargs: - # num_tuning_cores is only relevant for XGBModel. - kwargs.pop("num_tuning_cores") + # params only relevant to XGBModel + _xgb_params = ["num_tuning_cores", "tree_method"] + + for param in _xgb_params: + if param in kwargs: + kwargs.pop(param) if kind == "random": return RandomModel(*args, **kwargs) # type: ignore diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 299537b62eb6..5f6f82bf148b 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -58,7 +58,6 @@ def tune_tir( # pylint: disable=too-many-locals seed: Optional[int] = None, module_equality: str = "structural", special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None, - **kwargs, ) -> Database: """Tune a TIR function or an IRModule of TIR functions. @@ -155,7 +154,6 @@ def tune_tir( # pylint: disable=too-many-locals measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, - **kwargs, ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index a65bd70a2481..887941ada0d2 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -41,7 +41,6 @@ def tune_tasks( measure_callbacks: MeasureCallback.CallbackListType = "default", task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", module_equality: str = "structural", - **kwargs, ) -> Database: """Tune a list of tasks. Using a task scheduler. @@ -109,7 +108,7 @@ def tune_tasks( elif not isinstance(database, Database): database = Database.create(database, module_equality=module_equality) if not isinstance(cost_model, CostModel): - cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, **kwargs) + cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto") if isinstance(measure_callbacks, MeasureCallback): measure_callbacks = [measure_callbacks] elif measure_callbacks == "default":