Merged
2 changes: 0 additions & 2 deletions python/tvm/auto_scheduler/__init__.py
@@ -19,7 +19,6 @@

 from . import compute_dag
 from . import dispatcher
-from . import env
 from . import feature
 from . import loop_state
 from . import measure
@@ -36,7 +35,6 @@
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
 from .dispatcher import DispatchContext, ApplyHistoryBest
-from .env import enable_relay_integration, is_relay_integration_enabled
 from .measure import (
     MeasureInput,
     MeasureResult,
30 changes: 20 additions & 10 deletions python/tvm/auto_scheduler/dispatcher.py
@@ -44,7 +44,7 @@ class DispatchContext(object):
     def __init__(self):
         self._old_ctx = DispatchContext.current

-    def query(self, target, workload_key):
+    def query(self, target, workload_key, has_complex_op, dag):
         """
         Query the context to get the specific config for a workload.
         If cannot find the result inside this context, this function will query it
@@ -56,6 +56,10 @@ def query(self, target, workload_key):
             The current target
         workload_key : str
             The workload key
+        has_complex_op: bool
+            Whether this workload has at least one complex op.
+        dag: ComputeDAG
+            The ComputeDAG of the workload.

         Returns
         -------
@@ -64,7 +68,7 @@
         """
         ret = self._query_inside(target, workload_key)
         if ret is None:
-            ret = self._old_ctx.query(target, workload_key)
+            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
         return ret

     def update(self, target, workload_key, state):
@@ -220,11 +224,11 @@ def _query_inside(self, target, workload_key):

     def update(self, target, workload_key, state):
         model = target.model
-        key = (model, workload)
+        key = (model, workload_key)
         self._best_user_defined[key] = state

         for k in target.keys:
-            key = (k, workload)
+            key = (k, workload_key)
             self._best_user_defined[key] = state


@@ -237,21 +241,27 @@ class FallbackContext(DispatchContext):
     def __init__(self):
         super(FallbackContext, self).__init__()
         self.memory = {}
-        self.silent = False
+
+        # Verbose level:
+        # 0: Completely silent.
+        # 1: Warning the missing configs for querying complex tasks.
+        # 2: Warning the missing configs for querying all tasks.
+        self.verbose = 1

         # a set to prevent print duplicated message
         self.messages = set()

-    def query(self, target, workload_key):
+    def query(self, target, workload_key, has_complex_op, dag):
         key = (str(target), workload_key)
         if key in self.memory:
             return self.memory[key]

-        if not self.silent:
+        if self.verbose == 2 or (has_complex_op and self.verbose == 1):
             msg = (
-                "Cannot find tuned schedules for target=%s, workload_key=%s. "
-                "A fallback schedule is used, "
-                "which may bring great performance regression." % (target, workload_key)
+                "Cannot find tuned schedules for target=%s, workload_key=%s, compute:\n%s"
+                "A fallback TOPI schedule is used, "
+                "which may bring great performance regression or even compilation failure."
+                % (target, workload_key, dag)
             )
             if msg not in self.messages:
                 self.messages.add(msg)
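
Note on the new knob: the boolean `silent` flag is replaced by a three-level `verbose` setting, so missing-config warnings can be limited to workloads that actually contain a complex op. A minimal usage sketch (assuming no tuning log has been applied, so `DispatchContext.current` is still the `FallbackContext`):

    from tvm import auto_scheduler

    ctx = auto_scheduler.DispatchContext.current  # the FallbackContext by default
    ctx.verbose = 0  # 0: silent; 1 (default): warn only for complex-op tasks; 2: warn for all tasks
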
56 changes: 0 additions & 56 deletions python/tvm/auto_scheduler/env.py

This file was deleted.
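
The deleted module provided `enable_relay_integration()` and `is_relay_integration_enabled()`; the `relay.backend.use_auto_scheduler` PassContext option (see `call_all_topi_funcs` below) takes over that role. A rough sketch of how a build would opt in now, assuming `mod`, `params`, and `target` are already defined:

    import tvm
    from tvm import relay

    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target=target, params=params)
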

59 changes: 37 additions & 22 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -25,7 +25,7 @@
 import threading

 import tvm
-from tvm import te, transform
+from tvm import autotvm, te, transform
 from tvm.te.tensor import ComputeOp, PlaceholderOp
 from .compute_dag import ComputeDAG
 from .dispatcher import DispatchContext
@@ -34,18 +34,26 @@


 def call_all_topi_funcs(mod, params, target):
-    """Call all TOPI compute + schedule to extract tasks in a relay program"""
+    """Call all TOPI compute to extract auto_scheduler tasks in a Relay program"""
     # pylint: disable=import-outside-toplevel
     from tvm import relay
     from tvm.relay.backend import graph_runtime_codegen

-    with transform.PassContext(opt_level=3):
+    # Turn off AutoTVM config not found warnings
+    old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
+    autotvm.GLOBAL_SCOPE.silent = True
+
+    with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
         opt_mod, _ = relay.optimize(mod, target, params)
         grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
         grc.codegen(opt_mod["main"])

+    autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent


-def extract_tasks(mod, params, target, target_host=None, hardware_params=None):
+def extract_tasks(
+    mod, params, target, target_host=None, hardware_params=None, include_simple_tasks=False
+):
     """Extract tuning tasks from a relay program.

     Parameters
@@ -60,6 +68,8 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None):
         The host compilation target
     hardware_params : Optional[HardwareParams]
         Hardware parameters used for the search tasks
+    include_simple_tasks: bool
+        Whether to extract simple tasks that do not include complicated ops.

     Returns
     -------
@@ -77,7 +87,9 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None):
         target_host = tvm.target.Target(target_host)

     # Run the compiler to collect all TOPI calls during compilation.
-    env = TracingEnvironment(TracingMode.EXTRACT_TASK)
+    env = TracingEnvironment(
+        TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
+    )
     with env:
         # Wrap build call in a new thread to avoid the conflict
         # between python's multiprocessing and tvm's thread pool
@@ -109,7 +121,8 @@ class TracingMode:
     """Two modes for tracing"""

     EXTRACT_TASK = 0  # trace all topi calls to extract tasks
-    PREPARE_LAYOUT_REWRITE = 1  # trace topi calls to prepare layout rewrite
+    EXTRACT_COMPLEX_TASK_ONLY = 1  # same as EXTRACT_TASK but ignore the task without complex ops
+    PREPARE_LAYOUT_REWRITE = 2  # trace topi calls to prepare layout rewrite


class TracingEnvironment:
@@ -181,11 +194,8 @@ def traverse(t):
     return inputs + list(outs), has_layout_free


-# The suffix of implementations that use the auto-scheduler in the OpStrategy.
-auto_schedule_impl_suffix = ".auto_scheduler"
-
-
-def auto_schedule_topi(outs):
+@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
+def auto_schedule_topi(outs, has_complex_op):
     """Use auto-scheduler to schedule any topi compute function.

     Note: This is used internally for relay integration. Do
@@ -195,35 +205,40 @@ def auto_schedule_topi(outs):
     ----------
     outs: List[Tensor]
         The output tensors of topi compute functions
+    has_complex_op: bool
+        Whether the topi compute function includes at least one complex op.

     Returns
     -------
-    sch: te.Schedule
-        A topi schedule function
+    sch: Optional[te.Schedule]
+        A tuned schedule or None (if not tuned) in the final build mode;
+        an initial schedule in the tracing mode.
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

io_tensors, has_layout_free = traverse_to_get_io_tensors(outs)
key = register_workload_tensors(io_tensors)
if key is None: # skip this compute if failed to register the workload
return None

# only enable layout rewrite for cpu backend
enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys

env = TracingEnvironment.current
if env is None: # in the final build mode
state = DispatchContext.current.query(tvm.target.Target.current(), key)
dag = ComputeDAG(io_tensors)
state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag)
if state is None:
if "gpu" in tvm.target.Target.current().keys:
raise RuntimeError("Cannot compile for GPU targets if no valid schedule is found.")
return te.create_schedule([x.op for x in outs])
return None

dag = ComputeDAG(io_tensors)
schedule, _ = dag.apply_steps_from_state(state)
elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode
engine = relay.backend.compile_engine.get()
ccache_key = engine.get_current_ccache_key()
env.add_workload_key(key, ccache_key)
elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]:
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
engine = relay.backend.compile_engine.get()
ccache_key = engine.get_current_ccache_key()
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# todo(merrymercy, minminsun): port layout rewrite
Expand Down
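
For reference, a sketch of how the new extraction flag surfaces to users (`mod`, `params`, and `target` are placeholders):

    from tvm import auto_scheduler

    # Default: only workloads containing at least one complex op become tasks.
    tasks = auto_scheduler.extract_tasks(mod, params, target)

    # Opt in to simple workloads (e.g. ones with only elementwise ops) as well.
    tasks = auto_scheduler.extract_tasks(mod, params, target, include_simple_tasks=True)
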
14 changes: 11 additions & 3 deletions python/tvm/auto_scheduler/workload_registry.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name

 """
 Workload registration and serialization.
@@ -29,12 +30,14 @@
 When we need the dag, we decode the string and call the function, which will return the dag.
 """

+import logging
 import pickle
 import json

 import tvm._ffi
 from .utils import serialize_args, deserialize_args, get_func_name

+logger = logging.getLogger("auto_scheduler")
+
 # Global workload function and hash key registry
 # It stores two types of workload:
@@ -105,13 +108,18 @@ def register_workload_tensors(tensors):

     Returns
     -------
-    key: str
-        The workload key
+    key: Optional[str]
+        The workload key, or None if failed to create a compute DAG.
     """
     # pylint: disable=import-outside-toplevel
     from .compute_dag import ComputeDAG

-    key = ComputeDAG(tensors).hash_key()
+    try:
+        key = ComputeDAG(tensors).hash_key()
+    except tvm.error.TVMError as err:
+        logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
+        return None

     WORKLOAD_FUNC_REGISTRY[key] = tensors
     return json.dumps((key,))
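
Since `register_workload_tensors` can now return None, callers must handle the failure case; this is the pattern `auto_schedule_topi` above uses:

    key = register_workload_tensors(io_tensors)
    if key is None:  # ComputeDAG creation failed; skip auto-scheduling this compute
        return None
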
24 changes: 7 additions & 17 deletions python/tvm/relay/backend/compile_engine.py
@@ -21,7 +21,7 @@
 import logging
 import numpy as np
 import tvm
-from tvm import te, autotvm, auto_scheduler
+from tvm import te, autotvm
 from tvm.runtime import Object
 from tvm.support import libinfo
 from tvm.target import Target
@@ -196,25 +196,13 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
         outs = best_plevel_impl.compute(attrs, inputs, out_type)
         return best_plevel_impl, outs

-    # If auto-scheduler is enabled for Relay, always prefer auto-scheduler
-    if auto_scheduler.is_relay_integration_enabled():
-        auto_scheduler_impls = []
-        for impl in all_impls:
-            if impl.name.endswith(auto_scheduler.relay_integration.auto_schedule_impl_suffix):
-                auto_scheduler_impls.append(impl)
-
-        if auto_scheduler_impls:
-            assert len(auto_scheduler_impls) == 1
-            impl = auto_scheduler_impls[0]
-            outs = impl.compute(attrs, inputs, out_type)
-            return impl, outs
-
     # Otherwise, try autotvm templates
     outputs = {}
     workloads = {}
     best_autotvm_impl = None
     best_cfg = None
     dispatch_ctx = autotvm.task.DispatchContext.current
+    old_silent = autotvm.GLOBAL_SCOPE.silent
     autotvm.GLOBAL_SCOPE.silent = True
     for impl in all_impls:
         outs = impl.compute(attrs, inputs, out_type)
@@ -232,7 +220,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
         if best_cfg is None or best_cfg.cost > cfg.cost:
             best_autotvm_impl = impl
             best_cfg = cfg
-    autotvm.GLOBAL_SCOPE.silent = False
+    autotvm.GLOBAL_SCOPE.silent = old_silent

     if best_autotvm_impl:
         # The best autotvm implementation definitely doesn't use fallback config
@@ -251,7 +239,10 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
             "is used, which may bring great performance regression."
             % (target, workloads[best_plevel_impl])
         )
-        if msg not in autotvm.task.DispatchContext.warning_messages:
+        if (
+            not autotvm.env.GLOBAL_SCOPE.silent
+            and msg not in autotvm.task.DispatchContext.warning_messages
+        ):
             autotvm.task.DispatchContext.warning_messages.add(msg)
             autotvm_logger.warning(msg)
     logger.info(
@@ -300,7 +291,6 @@ def lower_call(call, inputs, target):
         best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
     else:
         # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
-        # Currently, we just use the implementation with highest plevel
        best_impl, outputs = select_implementation(
             op, call.attrs, inputs, ret_type, target, use_autotvm=False
         )
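
The silent-flag change replaces an unconditional `silent = False` with save-and-restore, so `select_implementation` no longer clobbers a caller that already silenced AutoTVM (as `call_all_topi_funcs` now does during task extraction). The shape of the pattern, as a sketch:

    old_silent = autotvm.GLOBAL_SCOPE.silent
    autotvm.GLOBAL_SCOPE.silent = True
    # ... query AutoTVM configs without "config not found" warnings ...
    autotvm.GLOBAL_SCOPE.silent = old_silent  # restore the caller's setting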