Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 23 additions & 56 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
# under the License.
"""MetaSchedule-Relay integration"""
from contextlib import contextmanager
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

# isort: off
from typing_extensions import Literal

# isort: on
import numpy as np # type: ignore

from tvm import nd
from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
Expand Down Expand Up @@ -71,15 +71,8 @@ def _normalize_params(
mod: IRModule,
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
pass_config: Mapping[str, Any],
executor: Optional["relay.backend.Executor"],
) -> Tuple[
IRModule,
Target,
Dict[str, NDArray],
Dict[str, Any],
Optional["relay.backend.Executor"],
]:
) -> Tuple[IRModule, Target, Dict[str, NDArray], Optional["relay.backend.Executor"],]:
from tvm import relay # pylint: disable=import-outside-toplevel

if isinstance(mod, relay.Function):
Expand All @@ -102,25 +95,16 @@ def _normalize_params(
else:
executor = mod.get_attr("executor")

pass_config = dict(pass_config)
return mod, target, relay_params, pass_config, executor
return mod, target, relay_params, executor


def extract_tasks(
mod: IRModule,
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
*,
opt_level: int = 3,
pass_config: Mapping[str, Any] = MappingProxyType(
{
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
),
executor: Optional["relay.backend.Executor"] = None,
module_equality: str = "structural",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.

Expand All @@ -132,10 +116,6 @@ def extract_tasks(
The compilation target
params : Optional[Dict[str, tvm.runtime.NDArray]]
The associated parameters of the program
opt_level : int
The optimization level of the compilation
pass_config : Mapping[str, Any]
The pass configuration
executor : Optional[relay.backend.Executor]
The executor to use
module_equality : Optional[str]
Expand All @@ -148,8 +128,6 @@ def extract_tasks(
given module. The "ignore-ndarray" varint is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes

Returns
-------
Expand All @@ -160,21 +138,25 @@ def extract_tasks(
from tvm import autotvm

# pylint: enable=import-outside-toplevel
mod, target, params, pass_config, _ = _normalize_params(
mod, target, params, pass_config, executor
)
mod, target, params, _ = _normalize_params(mod, target, params, executor)
if target.kind.name != "cuda" and isinstance(
autotvm.DispatchContext.current, autotvm.FallbackContext
):
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.utils.EmptyContext()
pass_ctx = transform.PassContext.current()
pass_config = dict(pass_ctx.config)
pass_config.setdefault("relay.backend.use_meta_schedule", True)
pass_config.setdefault("relay.backend.tir_converter", "default")
with Profiler.timeit("TaskExtraction"):
with target, _autotvm_silencer(), tophub_context:
with transform.PassContext(
opt_level=opt_level,
opt_level=pass_ctx.opt_level,
required_pass=pass_ctx.required_pass,
disabled_pass=pass_ctx.disabled_pass,
instruments=pass_ctx.instruments,
config=pass_config,
disabled_pass=disabled_pass,
):
return list(_extract_task(mod, target, params, module_equality))

Expand Down Expand Up @@ -254,7 +236,6 @@ def tune_relay(
seed: Optional[int] = None,
module_equality: str = "structural",
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> Database:
"""Tune a Relay program.

Expand Down Expand Up @@ -304,18 +285,14 @@ def tune_relay(
For the definition of the anchor block, see tir/analysis/analysis.py.
num_tuning_cores : Union[Literal["physical", "logical"], int]
The number of CPU cores to use during tuning.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes during tasks extraction

Returns
-------
database : Database
The database that contains the tuning records
"""
tasks, task_weights = extracted_tasks_to_tune_contexts(
extracted_tasks=extract_tasks(
mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass
),
extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality),
work_dir=work_dir,
space=space,
strategy=strategy,
Expand Down Expand Up @@ -346,15 +323,7 @@ def compile_relay(
params: Optional[Dict[str, NDArray]],
*,
backend: Literal["graph", "vm"] = "graph",
opt_level: int = 3,
pass_config: Mapping[str, Any] = MappingProxyType(
{
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
),
executor: Optional["relay.backend.Executor"] = None,
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
):
"""Compile a relay program with a MetaSchedule database.

Expand All @@ -372,14 +341,8 @@ def compile_relay(
The backend to use. Builtin backends:
- "graph"
- "vm"
opt_level : int
The optimization level of the compilation
pass_config : Mapping[str, Any]
The pass configuration
executor : Optional[relay.backend.Executor]
The executor to use in relay.build. It is not supported by RelayVM.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes

Returns
-------
Expand All @@ -390,16 +353,20 @@ def compile_relay(
from tvm import relay

# pylint: enable=import-outside-toplevel
mod, target, params, pass_config, executor = _normalize_params(
mod, target, params, pass_config, executor
)
mod, target, params, executor = _normalize_params(mod, target, params, executor)
pass_ctx = transform.PassContext.current()
pass_config = dict(pass_ctx.config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To prevent breaking user's script after this change, how about setting a default ctx with opt_level = 3 if there is no current ctx?

Copy link
Contributor Author

@tkonolige tkonolige Jan 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now there is no way to determine whether or not if the current pass context is the default or not. It might be possible to add though.

It is not consistent with the rest of TVM to default the opt_level to 3 in meta schedule. Everywhere else defaults to 2. This does mean that we might break some user scripts.

pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True)
pass_config.setdefault("relay.backend.use_meta_schedule", True)
pass_config.setdefault("relay.backend.tir_converter", "default")
with Profiler.timeit("PostTuningCompilation"):
with target, _autotvm_silencer(), database:
with transform.PassContext(
opt_level=opt_level,
opt_level=pass_ctx.opt_level,
required_pass=pass_ctx.required_pass,
disabled_pass=pass_ctx.disabled_pass,
instruments=pass_ctx.instruments,
config=pass_config,
disabled_pass=disabled_pass,
):
if backend == "graph":
return relay.build(mod, target=target, params=params, executor=executor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def schedule_conv2d_for_tune(sch: Schedule):
else None
)

with tempfile.TemporaryDirectory() as work_dir:
with tempfile.TemporaryDirectory() as work_dir, pass_context:
database = ms.relay_integration.tune_relay(
mod=mod,
target=TARGET_HEXAGON,
Expand Down Expand Up @@ -382,15 +382,11 @@ def schedule_conv2d_for_tune(sch: Schedule):
module_equality="ignore-ndarray",
)

# Add default options so that it still uses the base config.
pass_config["relay.backend.use_meta_schedule"] = True
pass_config["relay.backend.tir_converter"] = "default"
return ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=TARGET_HEXAGON,
params=params,
pass_config=pass_config,
)


Expand Down