21 changes: 13 additions & 8 deletions python/tvm/contrib/hexagon/meta_schedule.py
```diff
@@ -128,7 +128,9 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path
     return costs
 
 
-def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
+def get_hexagon_local_builder(
+    pass_context: tvm.transform.PassContext = None, max_workers: Optional[int] = None
+):
     """Return Hexagon-compatible Builder for meta schedule."""
 
     def export_func(mod):
@@ -143,13 +145,19 @@ def default_build_with_context(
         return tvm_build(mod, target=target)
 
     if pass_context is not None:
-        return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
+        return LocalBuilder(
+            f_build=default_build_with_context, f_export=export_func, max_workers=max_workers
+        )
     else:
-        return LocalBuilder(f_export=export_func)
+        return LocalBuilder(f_export=export_func, max_workers=max_workers)
 
 
 def get_hexagon_rpc_runner(
-    hexagon_launcher: HexagonLauncherRPC, number=3, repeat=1, min_repeat_ms=100
+    hexagon_launcher: HexagonLauncherRPC,
+    number=3,
+    repeat=1,
+    min_repeat_ms=100,
+    max_workers: Optional[int] = None,
 ):
     """Return Hexagon-compatible RPC Runner for meta schedule.
 
@@ -177,7 +185,4 @@ def get_hexagon_rpc_runner(
         enable_cpu_cache_flush=False,
     )
 
-    return HexagonRPCRunner(
-        hexagon_launcher,
-        evaluator_config,
-    )
+    return HexagonRPCRunner(hexagon_launcher, evaluator_config, max_workers=max_workers)
```
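Taken together, the two Hexagon helpers now let a caller cap build-side and measurement-side parallelism explicitly. A minimal usage sketch, assuming an already-constructed `HexagonLauncherRPC` named `launcher` (the worker count of 4 is illustrative):

```python
from tvm.contrib.hexagon.meta_schedule import (
    get_hexagon_local_builder,
    get_hexagon_rpc_runner,
)

# Cap both the local build pool and the RPC measurement pool at 4 workers;
# leaving max_workers=None keeps the previous default behavior.
builder = get_hexagon_local_builder(max_workers=4)
runner = get_hexagon_rpc_runner(launcher, number=3, repeat=1, min_repeat_ms=100, max_workers=4)
```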
4 changes: 2 additions & 2 deletions python/tvm/contrib/torch/as_torch.py
```diff
@@ -67,7 +67,7 @@ def tune(
         space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply",
         strategy: ms.SearchStrategy.SearchStrategyType = "replay-trace",
         task_name: str = "main",
-        num_threads: Union[Literal["physical", "logical"], int] = "physical",
+        num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
         seed: Optional[int] = None,
     ) -> None:
         """
@@ -100,7 +100,7 @@ def tune(
             space=space,
             strategy=strategy,
             task_name=task_name,
-            num_threads=num_threads,
+            num_tuning_cores=num_tuning_cores,
             seed=seed,
         )
         sch = ms.tir_integration.compile_tir(database, self.ir_module, target)
```
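For the `as_torch` wrapper the change is purely a rename: `num_threads` becomes `num_tuning_cores`, with the same accepted values. A hedged sketch against a hypothetical `@as_torch`-decorated module `MyModule`, with all other tuning arguments left at their defaults:

```python
# "physical", "logical", or an explicit integer are accepted, as before;
# only the keyword name changed.
MyModule.tune(num_tuning_cores="physical")
```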
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/cost_model/cost_model.py
```diff
@@ -126,6 +126,11 @@ def create(
 
         if kind == "xgb":
             return XGBModel(*args, **kwargs)  # type: ignore
+
+        if "num_tuning_cores" in kwargs:
+            # num_tuning_cores is only relevant for XGBModel.
+            kwargs.pop("num_tuning_cores")
+
         if kind == "random":
             return RandomModel(*args, **kwargs)  # type: ignore
         if kind == "mlp":
```
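`CostModel.create` forwards `*args, **kwargs` straight into the selected model's constructor, so the new key must be stripped before it reaches models that do not declare it. A short sketch of the resulting behavior (worker counts illustrative):

```python
from tvm.meta_schedule.cost_model import CostModel

# Forwarded into XGBModel.__init__ as a real keyword argument.
xgb_model = CostModel.create("xgb", num_tuning_cores=8)

# Dropped before RandomModel.__init__, which has no such parameter,
# instead of raising a TypeError.
random_model = CostModel.create("random", num_tuning_cores=8)
```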
7 changes: 6 additions & 1 deletion python/tvm/meta_schedule/cost_model/xgb_model.py
```diff
@@ -333,6 +333,7 @@ def __init__(
         verbose_eval: int = 25,
         average_peak_n: int = 32,
         adaptive_training: bool = True,
+        num_tuning_cores: Optional[int] = None,
     ):
         super().__init__()
         if not isinstance(extractor, FeatureExtractor):
@@ -342,7 +343,11 @@ def __init__(
         # model-related
         if config.nthread is None:
             # use physical core number
-            config = config._replace(nthread=cpu_count(logical=False))
+            if num_tuning_cores is None:
+                config = config._replace(nthread=cpu_count(logical=False))
+            else:
+                config = config._replace(nthread=num_tuning_cores)
 
         self.config = config
         # behavior of randomness
         self.num_warmup_samples = num_warmup_samples
```
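The effective XGBoost thread count is therefore resolved in order: an `nthread` already set on the config wins, then `num_tuning_cores`, then the physical core count. A small sketch, assuming the constructor's other arguments keep their defaults:

```python
from tvm.meta_schedule.cost_model import XGBModel

# config.nthread is unset by default, so the model trains with 8 threads.
model = XGBModel(num_tuning_cores=8)
```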
12 changes: 8 additions & 4 deletions python/tvm/meta_schedule/relay_integration.py
```diff
@@ -180,7 +180,7 @@ def extracted_tasks_to_tune_contexts(
     work_dir: str,
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Tuple[List[TuneContext], List[float]]:
     """Convert ExtractedTask to TuneContext.
@@ -195,8 +195,8 @@ def extracted_tasks_to_tune_contexts(
         The space generator to use.
     strategy : SearchStrategy.SearchStrategyType
         The search strategy to use.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use in multi-threaded search algorithm.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The random seed to use.
 
@@ -223,7 +223,7 @@ def extracted_tasks_to_tune_contexts(
                 task_name=task.task_name,
                 logger=logger,
                 rand_state=rand_state,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         )
         task_weights.append(task.weight)
@@ -249,6 +249,7 @@ def tune_relay(
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     seed: Optional[int] = None,
     module_equality: str = "structural",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
 ) -> Database:
     """Tune a Relay program.
 
@@ -296,6 +297,8 @@ def tune_relay(
         given module. The "ignore-ndarray" variant is used for the extracted
         blocks or in case no anchor block is found.
         For the definition of the anchor block, see tir/analysis/analysis.py.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
 
     Returns
     -------
@@ -308,6 +311,7 @@ def tune_relay(
         space=space,
        strategy=strategy,
        seed=seed,
+        num_tuning_cores=num_tuning_cores,
     )
     return tune_tasks(
         tasks=tasks,
```
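With this plumbing in place, the core budget can be pinned once at the `tune_relay` entry point and it propagates into every `TuneContext`. A hedged end-to-end sketch, assuming `mod` and `params` hold a Relay model; the target string, work directory, and trial count are illustrative:

```python
from tvm import meta_schedule as ms

database = ms.relay_integration.tune_relay(
    mod=mod,
    params=params,
    target="llvm -num-cores 4",
    work_dir="./tuning_logs",
    max_trials_global=128,
    num_tuning_cores=4,  # lands on each TuneContext as num_threads
)
```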
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/runner/runner.py
```diff
@@ -194,6 +194,8 @@ def create( # pylint: disable=keyword-arg-before-vararg
         from . import LocalRunner, RPCRunner  # pylint: disable=import-outside-toplevel
 
         if kind == "local":
+            if "max_workers" in kwargs:
+                kwargs.pop("max_workers")
             return LocalRunner(*args, **kwargs)  # type: ignore
         elif kind == "rpc":
             return RPCRunner(*args, **kwargs)  # type: ignore
```
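Mirroring the cost-model guard, `Runner.create` drops `max_workers` for the local runner, which does not declare that keyword, while still forwarding it to the RPC runner. A minimal sketch of the local path:

```python
from tvm.meta_schedule.runner import Runner

# The hint is silently discarded for "local" rather than raising a TypeError.
local_runner = Runner.create("local", max_workers=4)
```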
8 changes: 4 additions & 4 deletions python/tvm/meta_schedule/tir_integration.py
```diff
@@ -54,7 +54,7 @@ def tune_tir(
     space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
     strategy: SearchStrategy.SearchStrategyType = "evolutionary",
     task_name: str = "main",
-    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
     seed: Optional[int] = None,
 ) -> Database:
     """Tune a TIR function.
@@ -89,8 +89,8 @@ def tune_tir(
         The search strategy.
     task_name : str
         The name of the task.
-    num_threads : Union[Literal["physical", "logical"], int]
-        The number of threads to use.
+    num_tuning_cores : Union[Literal["physical", "logical"], int]
+        The number of CPU cores to use during tuning.
     seed : Optional[int]
         The seed for the random number generator.
 
@@ -111,7 +111,7 @@
                 task_name=task_name,
                 logger=logger,
                 rand_state=seed,
-                num_threads=num_threads,
+                num_threads=num_tuning_cores,
             ).clone()
         ],
         task_weights=[1.0],
```
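`tune_tir` gets the same rename; note that internally the value is still handed to the underlying `TuneContext` as `num_threads`. A hedged sketch, assuming `matmul_mod` is an IRModule or PrimFunc defined elsewhere:

```python
from tvm import meta_schedule as ms

database = ms.tir_integration.tune_tir(
    mod=matmul_mod,
    target="llvm -num-cores 4",
    work_dir="./tuning_logs",
    max_trials_global=64,
    num_tuning_cores=4,  # renamed from num_threads
)
```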
12 changes: 9 additions & 3 deletions python/tvm/meta_schedule/tune.py
```diff
@@ -86,22 +86,28 @@ def tune_tasks(
     database : Database
         The database with all tuning records
     """
+    if len(tasks) == 0:
+        raise ValueError("No tasks to tune.")
+
+    if len(tasks) != len(task_weights):
+        raise ValueError(
+            f"Length of tasks ({len(tasks)}) and task_weights ({len(task_weights)}) do not match."
+        )
+
+    num_cores = tasks[0].num_threads
+
     if max_trials_per_task is None:
         max_trials_per_task = max_trials_global
     if not isinstance(builder, Builder):
-        builder = Builder.create(builder)
+        builder = Builder.create(builder, max_workers=num_cores)
     if not isinstance(runner, Runner):
-        runner = Runner.create(runner)
+        runner = Runner.create(runner, max_workers=num_cores)
     if database == "json":
         database = Database.create(database, work_dir=work_dir, module_equality=module_equality)
     elif not isinstance(database, Database):
         database = Database.create(database, module_equality=module_equality)
     if not isinstance(cost_model, CostModel):
-        cost_model = CostModel.create(cost_model)
+        cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
     if isinstance(measure_callbacks, MeasureCallback):
         measure_callbacks = [measure_callbacks]
     elif measure_callbacks == "default":
```
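The net effect in `tune_tasks` is that the first task's `num_threads` (now fed by `num_tuning_cores`) becomes the single source of truth fanned out to builder, runner, and cost model, and two previously silent misuses now fail fast. A sketch of the new error paths (the empty and mismatched inputs, and `ctx`, a `TuneContext`, are hypothetical):

```python
from tvm.meta_schedule.tune import tune_tasks

# Raises ValueError("No tasks to tune.") instead of silently doing nothing.
tune_tasks(tasks=[], task_weights=[], work_dir="/tmp/work", max_trials_global=8)

# Raises ValueError on the length mismatch between tasks and task_weights.
tune_tasks(tasks=[ctx], task_weights=[1.0, 2.0], work_dir="/tmp/work", max_trials_global=8)
```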
```diff
@@ -33,6 +33,7 @@
     get_hexagon_rpc_runner,
 )
 from tvm.meta_schedule import postproc, schedule_rule
+from tvm.meta_schedule.utils import cpu_count
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.analysis import has_block
 from tvm.tir.tensor_intrin.hexagon import (
@@ -44,10 +45,24 @@
 from ..infrastructure import get_hexagon_target
 
 MODEL_JSON = "resnet50_int8.json"
+MODEL_PARAMS = "resnet50_int8.params"
 EXECUTOR = relay.backend.Executor("graph", {"link-params": True})
 TARGET_LLVM = tvm.target.Target("llvm")
 TARGET_HEXAGON = get_hexagon_target("v68")
-MODEL_PARAMS = "resnet50_int8.params"
+
+
+def load_model():
+    """Load resnet50 model."""
+    if not os.path.exists(MODEL_JSON):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(MODEL_JSON, "r") as file:
+        mod = tvm.ir.load_json(file.read())
+
+    with open(MODEL_PARAMS, "rb") as file:
+        params = relay.load_param_dict(file.read())
+
+    return mod, params
 
 
 def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
@@ -110,6 +125,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
     # task extraction and relay.build(...).
     mod = mod.with_attr("executor", EXECUTOR)
 
+    num_cores = cpu_count(logical=False)
+
     with tempfile.TemporaryDirectory() as work_dir:
         database = ms.relay_integration.tune_relay(
             mod=mod,
@@ -125,8 +142,8 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # num_trials_per_iter=32,
             # max_trials_per_task=128,
             # strategy="evolutionary",
-            builder=get_hexagon_local_builder(),
-            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            builder=get_hexagon_local_builder(max_workers=num_cores),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20, max_workers=num_cores),
             space=ms.space_generator.PostOrderApply(
                 sch_rules=sch_rules,
                 postprocs=postprocs,
@@ -137,6 +154,7 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
             # It reduces the number of conv2d tuning tasks in the int8 resnet50 model
             # from 36 to 23, with negligible performance difference.
             module_equality="anchor-block",
+            num_tuning_cores=num_cores,
         )
     return ms.relay_integration.compile_relay(
         database=database,
@@ -156,11 +174,8 @@ def test_resnet50(hexagon_launcher):
     if not os.path.exists(MODEL_JSON):
         pytest.skip(msg="Run python export_models.py first.")
 
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
+    mod, params = load_model()
 
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
     inp = np.random.randn(1, 3, 224, 224).astype("float32")
     input_name = "image"
 
@@ -231,20 +246,6 @@ def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, in
     np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
 
 
-def load_model():
-    """Load renset50 model."""
-    if not os.path.exists(MODEL_JSON):
-        pytest.skip(msg="Run python export_models.py first.")
-
-    with open(MODEL_JSON, "r") as file:
-        mod = tvm.ir.load_json(file.read())
-
-    with open(MODEL_PARAMS, "rb") as file:
-        params = relay.load_param_dict(file.read())
-
-    return mod, params
-
-
 def _schedule_packed_8x8x32_conv2d():
     """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
     using 8x8x32 packed layout.
```
7 changes: 5 additions & 2 deletions tests/python/contrib/test_hexagon/test_meta_schedule.py
```diff
@@ -73,8 +73,11 @@ def test_builder_runner(hexagon_launcher):
 
     mod = MatmulModule
 
-    builder = get_hexagon_local_builder()
-    runner = get_hexagon_rpc_runner(hexagon_launcher, number=1, repeat=1, min_repeat_ms=0)
+    max_workers = 4
+    builder = get_hexagon_local_builder(max_workers=max_workers)
+    runner = get_hexagon_rpc_runner(
+        hexagon_launcher, number=1, repeat=1, min_repeat_ms=0, max_workers=max_workers
+    )
 
     (builder_result,) = builder.build([BuilderInput(mod, get_hexagon_target("v68"))])
     assert builder_result.artifact_path is not None
```
```diff
@@ -742,6 +742,7 @@ def _test_anchor_tuning(target):
         max_trials_global=4,
         strategy="replay-trace",
         module_equality=module_equality,
+        num_tuning_cores=4,
     )
     lib = ms.relay_integration.compile_relay(database, mod, target, params)
```