diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py
index dcc7d232d8c4..6e1541e498a9 100644
--- a/python/tvm/contrib/hexagon/meta_schedule.py
+++ b/python/tvm/contrib/hexagon/meta_schedule.py
@@ -128,7 +128,9 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path
     return costs
 
 
-def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
+def get_hexagon_local_builder(
+    pass_context: tvm.transform.PassContext = None, max_workers: Optional[int] = None
+):
     """Return Hexagon-compatible Builder for meta schedule."""
 
     def export_func(mod):
@@ -143,13 +145,19 @@ def default_build_with_context(
         return tvm_build(mod, target=target)
 
     if pass_context is not None:
-        return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
+        return LocalBuilder(
+            f_build=default_build_with_context, f_export=export_func, max_workers=max_workers
+        )
     else:
-        return LocalBuilder(f_export=export_func)
+        return LocalBuilder(f_export=export_func, max_workers=max_workers)
 
 
 def get_hexagon_rpc_runner(
-    hexagon_launcher: HexagonLauncherRPC, number=3, repeat=1, min_repeat_ms=100
+    hexagon_launcher: HexagonLauncherRPC,
+    number=3,
+    repeat=1,
+    min_repeat_ms=100,
+    max_workers: Optional[int] = None,
 ):
     """Return Hexagon-compatible RPC Runner for meta schedule.
 
@@ -177,7 +185,4 @@ def get_hexagon_rpc_runner(
         enable_cpu_cache_flush=False,
     )
 
-    return HexagonRPCRunner(
-        hexagon_launcher,
-        evaluator_config,
-    )
+    return HexagonRPCRunner(hexagon_launcher, evaluator_config, max_workers=max_workers)
diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py
index 918ce3ff3b6a..c4ca88adf738 100644
--- a/python/tvm/contrib/torch/as_torch.py
+++ b/python/tvm/contrib/torch/as_torch.py
@@ -67,7 +67,7 @@ def tune(
         space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply",
         strategy: ms.SearchStrategy.SearchStrategyType = "replay-trace",
         task_name: str = "main",
-        num_threads: Union[Literal["physical", "logical"], int] = "physical",
+        num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
         seed: Optional[int] = None,
     ) -> None:
         """
@@ -100,7 +100,7 @@ def tune(
             space=space,
             strategy=strategy,
             task_name=task_name,
-            num_threads=num_threads,
+            num_tuning_cores=num_tuning_cores,
             seed=seed,
         )
         sch = ms.tir_integration.compile_tir(database, self.ir_module, target)
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
index f139fcc4e4b3..c0f6ea5fb9e1 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -126,6 +126,11 @@ def create(
 
         if kind == "xgb":
             return XGBModel(*args, **kwargs)  # type: ignore
+
+        if "num_tuning_cores" in kwargs:
+            # num_tuning_cores is only relevant for XGBModel.
+            kwargs.pop("num_tuning_cores")
+
         if kind == "random":
             return RandomModel(*args, **kwargs)  # type: ignore
         if kind == "mlp":
diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py
index 0a2786c6abe0..901e18ce3fa5 100644
--- a/python/tvm/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -333,6 +333,7 @@ def __init__(
         verbose_eval: int = 25,
         average_peak_n: int = 32,
         adaptive_training: bool = True,
+        num_tuning_cores: Optional[int] = None,
     ):
         super().__init__()
         if not isinstance(extractor, FeatureExtractor):
@@ -342,7 +343,11 @@ def __init__(
         # model-related
         if config.nthread is None:
             # use physical core number
-            config = config._replace(nthread=cpu_count(logical=False))
+            if num_tuning_cores is None:
+                config = config._replace(nthread=cpu_count(logical=False))
+            else:
+                config = config._replace(nthread=num_tuning_cores)
+
         self.config = config
         # behavior of randomness
         self.num_warmup_samples = num_warmup_samples
diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index df76684d2d42..0b8705aafea9 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -180,7 +180,7 @@ def extracted_tasks_to_tune_contexts(
     work_dir: str,
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Tuple[List[TuneContext], List[float]]:
     """Convert ExtractedTask to TuneContext.
@@ -195,8 +195,8 @@ def extracted_tasks_to_tune_contexts(
         The space generator to use.
     strategy : SearchStrategy.SearchStrategyType
         The search strategy to use.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use in multi-threaded search algorithm.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The random seed to use.
 
@@ -223,7 +223,7 @@ def extracted_tasks_to_tune_contexts(
                 task_name=task.task_name,
                 logger=logger,
                 rand_state=rand_state,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         )
         task_weights.append(task.weight)
@@ -249,6 +249,7 @@ def tune_relay(
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     seed: Optional[int] = None,
     module_equality: str = "structural",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
 ) -> Database:
     """Tune a Relay program.
 
@@ -296,6 +297,8 @@ def tune_relay(
         given module. The "ignore-ndarray" varint is used for the extracted blocks
         or in case no anchor block is found.
         For the definition of the anchor block, see tir/analysis/analysis.py.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
 
     Returns
     -------
@@ -308,6 +311,7 @@ def tune_relay(
         space=space,
         strategy=strategy,
         seed=seed,
+        num_tuning_cores=num_tuning_cores,
     )
     return tune_tasks(
         tasks=tasks,
diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py
index 1753d8b4abf9..1a8f78414e91 100644
--- a/python/tvm/meta_schedule/runner/runner.py
+++ b/python/tvm/meta_schedule/runner/runner.py
@@ -194,6 +194,8 @@ def create(  # pylint: disable=keyword-arg-before-vararg
         from . import LocalRunner, RPCRunner  # pylint: disable=import-outside-toplevel
 
         if kind == "local":
+            if "max_workers" in kwargs:
+                kwargs.pop("max_workers")
             return LocalRunner(*args, **kwargs)  # type: ignore
         elif kind == "rpc":
             return RPCRunner(*args, **kwargs)  # type: ignore
diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py
index 975987ebcb67..f3d505c28b0e 100644
--- a/python/tvm/meta_schedule/tir_integration.py
+++ b/python/tvm/meta_schedule/tir_integration.py
@@ -54,7 +54,7 @@ def tune_tir(
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     task_name: str = "main",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Database:
     """Tune a TIR function.
@@ -89,8 +89,8 @@ def tune_tir(
         The search strategy.
     task_name : str
         The name of the task.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The seed for the random number generator.
 
@@ -111,7 +111,7 @@ def tune_tir(
                 task_name=task_name,
                 logger=logger,
                 rand_state=seed,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         ],
         task_weights=[1.0],
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index a69c8f126272..0c4035844c71 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -86,22 +86,28 @@ def tune_tasks(
     database : Database
         The database with all tuning records
     """
+    if len(tasks) == 0:
+        raise ValueError("No tasks to tune.")
+
     if len(tasks) != len(task_weights):
         raise ValueError(
             f"Length of tasks ({len(tasks)}) and task_weights ({len(task_weights)}) do not match."
         )
+
+    num_cores = tasks[0].num_threads
+
     if max_trials_per_task is None:
         max_trials_per_task = max_trials_global
     if not isinstance(builder, Builder):
-        builder = Builder.create(builder)
+        builder = Builder.create(builder, max_workers=num_cores)
     if not isinstance(runner, Runner):
-        runner = Runner.create(runner)
+        runner = Runner.create(runner, max_workers=num_cores)
     if database == "json":
         database = Database.create(database, work_dir=work_dir, module_equality=module_equality)
     elif not isinstance(database, Database):
         database = Database.create(database, module_equality=module_equality)
     if not isinstance(cost_model, CostModel):
-        cost_model = CostModel.create(cost_model)
+        cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
     if isinstance(measure_callbacks, MeasureCallback):
         measure_callbacks = [measure_callbacks]
     elif measure_callbacks == "default":
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index e15b0a4e7ddb..1e01cb28a749 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -33,6 +33,7 @@
     get_hexagon_rpc_runner,
 )
 from tvm.meta_schedule import postproc, schedule_rule
+from tvm.meta_schedule.utils import cpu_count
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.analysis import has_block
 from tvm.tir.tensor_intrin.hexagon import (
@@ -44,10 +45,24 @@
 from ..infrastructure import get_hexagon_target
 
 MODEL_JSON = "resnet50_int8.json"
+MODEL_PARAMS = "resnet50_int8.params"
 EXECUTOR = relay.backend.Executor("graph", {"link-params": True})
 TARGET_LLVM = tvm.target.Target("llvm")
 TARGET_HEXAGON = get_hexagon_target("v68")
-MODEL_PARAMS = "resnet50_int8.params"
+
+
+def load_model():
+    """Load renset50 model."""
+    if not os.path.exists(MODEL_JSON):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(MODEL_JSON, "r") as file:
+        mod = tvm.ir.load_json(file.read())
+
+    with open(MODEL_PARAMS, "rb") as file:
+        params = relay.load_param_dict(file.read())
+
+    return mod, params
 
 
 def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
@@ -110,6 +125,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
     # task extraction and relay.build(...).
     mod = mod.with_attr("executor", EXECUTOR)
 
+    num_cores = cpu_count(logical=False)
+
     with tempfile.TemporaryDirectory() as work_dir:
         database = ms.relay_integration.tune_relay(
             mod=mod,
@@ -125,8 +142,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # num_trials_per_iter=32,
            # max_trials_per_task=128,
             # strategy="evolutionary",
-            builder=get_hexagon_local_builder(),
-            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            builder=get_hexagon_local_builder(max_workers=num_cores),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20, max_workers=num_cores),
             space=ms.space_generator.PostOrderApply(
                 sch_rules=sch_rules,
                 postprocs=postprocs,
@@ -137,6 +154,7 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # It reduces the number of conv2d tuning tasks in the int8 resnet50 model
             # from 36 to 23, with negligible performance difference.
             module_equality="anchor-block",
+            num_tuning_cores=num_cores,
         )
         return ms.relay_integration.compile_relay(
             database=database,
@@ -156,11 +174,8 @@ def test_resnet50(hexagon_launcher):
     if not os.path.exists(MODEL_JSON):
         pytest.skip(msg="Run python export_models.py first.")
 
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
+    mod, params = load_model()
 
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
 
     inp = np.random.randn(1, 3, 224, 224).astype("float32")
     input_name = "image"
@@ -231,20 +246,6 @@ def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, in
     np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
 
 
-def load_model():
-    """Load renset50 model."""
-    if not os.path.exists(MODEL_JSON):
-        pytest.skip(msg="Run python export_models.py first.")
-
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
-
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
-
-    return mod, params
-
-
 def _schedule_packed_8x8x32_conv2d():
     """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
     using 8x8x32 packed layout.
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index a83a3b279a7f..1089f0f03589 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -73,8 +73,11 @@ def test_builder_runner(hexagon_launcher):
 
     mod = MatmulModule
 
-    builder = get_hexagon_local_builder()
-    runner = get_hexagon_rpc_runner(hexagon_launcher, number=1, repeat=1, min_repeat_ms=0)
+    max_workers = 4
+    builder = get_hexagon_local_builder(max_workers=max_workers)
+    runner = get_hexagon_rpc_runner(
+        hexagon_launcher, number=1, repeat=1, min_repeat_ms=0, max_workers=max_workers
+    )
 
     (builder_result,) = builder.build([BuilderInput(mod, get_hexagon_target("v68"))])
     assert builder_result.artifact_path is not None
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 021db0f86ad2..062da0b00ca3 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -742,6 +742,7 @@ def _test_anchor_tuning(target):
             max_trials_global=4,
             strategy="replay-trace",
             module_equality=module_equality,
+            num_tuning_cores=4,
         )
         lib = ms.relay_integration.compile_relay(database, mod, target, params)
 
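Usage sketch (illustrative, not part of the patch): with the changes above, the number of CPU cores used at tuning time can be capped by passing num_tuning_cores to ms.relay_integration.tune_relay; tune_tasks() then forwards that value to the builder and runner as max_workers and to the XGBoost cost model as its nthread setting. The tiny dense workload, trial count, and core count below are placeholder assumptions, not values taken from the patch.

import tempfile

import numpy as np
import tvm
from tvm import meta_schedule as ms
from tvm import relay

# A minimal Relay workload: a single dense layer with constant weights.
data = relay.var("data", shape=(1, 64), dtype="float32")
weight = relay.const(np.random.rand(32, 64).astype("float32"))
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))

target = tvm.target.Target("llvm -num-cores 4")

with tempfile.TemporaryDirectory() as work_dir:
    # num_tuning_cores becomes TuneContext.num_threads; tune_tasks() reuses it
    # as max_workers for the local builder/runner and as the XGBoost nthread.
    database = ms.relay_integration.tune_relay(
        mod=mod,
        params={},
        target=target,
        work_dir=work_dir,
        max_trials_global=8,
        strategy="replay-trace",
        num_tuning_cores=4,
    )
    lib = ms.relay_integration.compile_relay(database, mod, target, params={})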