diff --git a/.gitignore b/.gitignore index b9357018a64c..3c03e8ecda7a 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode @@ -233,3 +234,6 @@ conda/pkg # antlr files *.tokens *.interp + +# ansor tuning logs +scripts/*.json diff --git a/CMakeLists.txt b/CMakeLists.txt index d7faa8a4b666..5550b5f6b3a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS + src/ansor/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py new file mode 100644 index 000000000000..93a82f073ac3 --- /dev/null +++ b/python/tvm/ansor/__init__.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Ansor auto-scheduler""" + +from . import compute_dag +from . import measure +from . import serialization +from . import loop_state +from . import utils +from . 
import workload_registry + +# Shortcut +from .compute_dag import ComputeDAG +from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ + auto_schedule, EmptyPolicy +from .measure import MeasureInput, LocalBuilder, LocalRunner +from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ + load_from_file, write_measure_records_to_file +from .workload_registry import register_workload_func, \ + workload_key_to_dag, make_workload_key_func diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py new file mode 100644 index 000000000000..e7b8a59eb83b --- /dev/null +++ b/python/tvm/ansor/_ffi_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Register FFI APIs from C++ for the namespace tvm.ansor""" +import tvm._ffi + + +tvm._ffi._init_api("ansor", __name__) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py new file mode 100644 index 000000000000..8fddac567529 --- /dev/null +++ b/python/tvm/ansor/auto_schedule.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
# See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""User interface for auto-scheduler"""

import random

import tvm._ffi
from tvm.runtime import Object
from .measure import LocalBuilder, LocalRunner
from . import _ffi_api


@tvm._ffi.register_object("ansor.HardwareParams")
class HardwareParams(Object):
    """ The parameters of target hardware, this is used to guide the search process of
    SearchPolicy.

    Parameters
    ----------
    num_cores : int
    vector_unit_bytes : int
    cache_line_bytes : int
    max_unroll_vec : int
    max_innermost_split_factor : int
    """
    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
                 max_unroll_vec, max_innermost_split_factor):
        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
                                            vector_unit_bytes, cache_line_bytes,
                                            max_unroll_vec, max_innermost_split_factor)


@tvm._ffi.register_object("ansor.SearchTask")
class SearchTask(Object):
    """ The meta-information of a search task

    Parameters
    ----------
    dag : ComputeDAG
    workload_key : str
    target : tvm.target.Target
    target_host : tvm.target.Target
    hardware_params : HardwareParams
    """
    def __init__(self, dag, workload_key, target, target_host=None,
                 hardware_params=None):
        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
                                            workload_key, target, target_host,
                                            hardware_params)


@tvm._ffi.register_object("ansor.SearchPolicy")
class SearchPolicy(Object):
    """ The base class for search policy """


@tvm._ffi.register_object("ansor.EmptyPolicy")
class EmptyPolicy(SearchPolicy):
    """ This is an example empty search policy which will always generate
    the init state of target ComputeDAG.
    """
    def __init__(self):
        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)


@tvm._ffi.register_object("ansor.SearchCallback")
class SearchCallback(Object):
    """ Callback function before or after search process """


@tvm._ffi.register_object("ansor.TuneOption")
class TuneOption(Object):
    """ The options for tuning

    Parameters
    ----------
    n_trials: int
        Number of total measurement trials
    early_stopping: int
        Stop the tuning early if no improvement is seen after n measurements
    num_measure_per_iter: int
        The number of programs to be measured at each iteration
    verbose: int
        Verbosity level. 0 means silent.
    builder: Builder
        Builder which builds the program
    runner: Runner
        Runner which runs the program and measure time costs
    measure_callbacks: List[MeasureCallback]
        Callback functions called after each measure
        Candidates:
          - ansor.LogToFile
    pre_search_callbacks: List[SearchCallback]
        Callback functions called before the search process
        Candidates:
          - ansor.PreloadMeasuredStates
          - ansor.PreloadCustomSketchRule
    """
    def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
                 verbose=1, builder='local', runner='local', measure_callbacks=None,
                 pre_search_callbacks=None):
        if isinstance(builder, str):
            if builder == 'local':
                builder = LocalBuilder()
            else:
                raise ValueError("Invalid builder: " + builder)

        if isinstance(runner, str):
            if runner == 'local':
                runner = LocalRunner()
            else:
                # Fixed: the original message said "Invalid builder" here,
                # which mislabels the offending argument.
                raise ValueError("Invalid runner: " + runner)

        if measure_callbacks is None:
            measure_callbacks = []

        if pre_search_callbacks is None:
            pre_search_callbacks = []

        self.__init_handle_by_constructor__(
            _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter,
            verbose, builder, runner, measure_callbacks, pre_search_callbacks)


def auto_schedule(workload, target=None,
                  target_host=None, search_policy='default',
                  hardware_params=None, tune_option=None):
    """ Do auto scheduling for a computation declaration.

    The workload parameter can be a `string` as workload_key, or directly
    passing a `SearchTask` as input.

    Parameters
    ----------
    workload : Union[SearchTask, str]
    target : Target
    target_host : Target = None
    search_policy : Union[SearchPolicy, str]
    hardware_params : HardwareParams
    tune_option : TuneOption

    Returns
    -------
    sch : tvm.Schedule
    tensors : List[Tensor]

    Raises
    ------
    ValueError
        If `search_policy` is an unknown string, or `workload` is neither a
        str nor a SearchTask.
    """
    if isinstance(search_policy, str):
        if search_policy == 'default':
            search_policy = EmptyPolicy()
        else:
            raise ValueError("Invalid search policy: " + search_policy)

    if tune_option is None:
        tune_option = TuneOption(n_trials=0)

    if isinstance(workload, str):
        sch, tensors = _ffi_api.AutoScheduleByWorkloadKey(
            workload, target, target_host, search_policy, hardware_params, tune_option)
        return sch, tensors
    if isinstance(workload, SearchTask):
        sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option)
        return sch, tensors
    # Fixed: use str(workload) so a non-str, non-SearchTask argument raises
    # the intended ValueError instead of a TypeError from str concatenation.
    raise ValueError("Invalid workload: " + str(workload) + ". Expect a string or SearchTask")

# ===== python/tvm/ansor/compute_dag.py (new file) =====
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
+ +""" Computational graph and its analysis tools """ + +import tvm._ffi +from tvm.runtime import Object +from .loop_state import State, StateObject +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.ComputeDAG") +class ComputeDAG(Object): + """ + Computation declaration graph + + Parameters + ---------- + tensors : List[Tensor] + """ + def __init__(self, tensors): + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) + + def get_init_state(self): + """ Get init state of this ComputeDAG + + Returns + ------- + state : State + """ + return State(_ffi_api.ComputeDAGGetInitState(self), self) + + def apply_steps_from_state(self, state): + """ + Apply transform steps according to the history of a state + + Parameters + ---------- + state : StateObject + layout_rewrite_level : LayoutRewriteLevel + + Returns + ------- + sch : Schedule + args : List[Tensor] + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) + + def print_python_code_from_state(self, state): + """ + Print transform steps in the history of a state as TVM's python schedule primitive + + Parameters + ---------- + state : StateObject + + Returns + ------- + str : Str + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py new file mode 100644 index 000000000000..bf81311ed664 --- /dev/null +++ b/python/tvm/ansor/loop_state.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import

"""
The definition of the "state" in search. A state consists a current loop structure
and the transform history to reach its current loop structure.
To enable flexible manipulation of the loop structure, we implemented a lightweight
loop structure IR (Intermediate Representation) specifically for search.

Basically this is a simplified TVM IR with schedule primitives.
We don't use the existing TVM IR because
1. We want fast incremental change to the loop structures
2. We want serializable transformation history for replay, backtracking, and mutation
3. We may create some new macro schedule primitives

After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives.
Because we share a lot common objects during search, the transformation is
implemented in copy on write style. All objects are immutable, which is
similar to TVM IR.
"""

import tvm._ffi
from tvm.te.tensor import Operation, Tensor
from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object("ansor.Iterator")
class Iterator(Object):
    """A for loop iterator"""


@tvm._ffi.register_object("ansor.Stage")
class Stage(Object):
    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""

    @property
    def iters(self):
        """
        Returns
        -------
        iters : List[Iterator]
        """
        # Cache the FFI call result on the instance; iterators are immutable.
        if not hasattr(self, "iterators_cache"):
            setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self))
        return getattr(self, "iterators_cache")


@tvm._ffi.register_object("ansor.State")
class StateObject(Object):
    """The internal State object """
    # NOTE(review): defining __eq__ makes this class unhashable
    # (__hash__ is implicitly set to None) — confirm no caller uses
    # StateObject as a dict key or set member.
    def __eq__(self, other):
        return _ffi_api.StateEqual(self, other)


class State:
    """
    A state in the search process. It consists of the current loop structure
    and the history steps to reach this state.

    Notes
    -----
    This is a wrapper class of StateObject to deal with copy-on-write property
    """
    def __init__(self, state_object, dag):
        self.state_object = state_object
        self.compute_dag = dag

        self.stages_cache = None  # A list to cache all stages
        self.stage_id_map = {}    # A dict maps operation to stage id
        self._update_stage_id_map()

    @property
    def stages(self):
        """
        Returns
        -------
        stages : List[Stage]
        """
        if not self.stages_cache:
            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
        return self.stages_cache

    @property
    def stage_ops(self):
        """
        Returns
        -------
        ops: List[Operation]
        """
        if not self.stages_cache:
            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
        return [stage.op for stage in self.stages_cache]

    def transform_steps_size(self):
        """ Return the size of transform_steps
        """
        return _ffi_api.StateGetTransformStepsSize(self.state_object)

    def reorder(self, stage_id, order):
        """
        Parameters
        ----------
        stage_id : Union[int, Operation, Tensor]
            The index of the stage to reorder
        order : List[Iterator]
            Iterators in the expected order
        """
        stage_id = self._resolve_stage_id(stage_id)

        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
        self._clear_cache()

    def split(self, stage_id, iterator, lengths, inner_to_outer=True):
        """
        Parameters
        ----------
        stage_id : Union[int, Operation, Tensor]
            The index of the stage to split
        iterator : Iterator
            The iterator to split
        lengths: List[int]
            The split factors
        inner_to_outer: bool
            True to use `factor` to split from inner to outer,
            False to use `nparts` to split from outer to inner

        Returns
        -------
        res_its : List[Iterator]
            The split new Iterators
        """
        stage_id = self._resolve_stage_id(stage_id)

        self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
                                                     inner_to_outer)
        self._clear_cache()
        return res

    def fuse(self, stage_id, iters):
        """
        Parameters
        ----------
        stage_id : Union[int, Operation, Tensor]
            The index of the stage to fuse
        iters : List[Iterator]
            The iterators to be fused

        Returns
        -------
        res_it : Iterator
            The fused Iterator
        """
        stage_id = self._resolve_stage_id(stage_id)

        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
        self._clear_cache()
        return res

    def _resolve_stage_id(self, stage_id):
        # Accept an Operation, a Tensor (resolved through its op), or a plain index.
        if isinstance(stage_id, Operation):
            return self.stage_id_map[stage_id]
        # Fixed: use the imported Tensor class instead of `tvm.te.Tensor`;
        # this module never imports `tvm` as a plain name, only `tvm._ffi`.
        if isinstance(stage_id, Tensor):
            return self.stage_id_map[stage_id.op]
        if isinstance(stage_id, int):
            return stage_id
        raise ValueError("Invalid stage_id")

    def _update_stage_id_map(self):
        if not self.stages_cache:
            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
        for index, stage in enumerate(self.stages_cache):
            self.stage_id_map[stage.op] = index

    def _insert_new_stage(self, new_stage_id):
        new_stage_id = int(new_stage_id)
        self.stages_cache = _ffi_api.StateGetStages(self.state_object)
        added_op = self.stages_cache[new_stage_id].op

        # Adding a new stage changes all ops. But we still want to use the old
        # ops to index stages, so we keep updating them and do not remove the
        # old ops.

        # Shift the ids of old ops at or after the insertion point, so the old
        # ops can still be used to index stages.
        for key, value in self.stage_id_map.items():
            if value >= new_stage_id:
                self.stage_id_map[key] = value + 1
        self.stage_id_map[added_op] = new_stage_id

        # Update stage_id_map for new ops
        self._update_stage_id_map()

        return added_op

    def _clear_cache(self):
        self.stages_cache = None

    def copy(self):
        state = State(self.state_object, self.compute_dag)
        state.stage_id_map = self.stage_id_map.copy()
        return state

    def __getitem__(self, key):
        if not self.stages_cache:
            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
        if isinstance(key, Tensor):
            key = key.op
        if isinstance(key, Operation):
            return self.stages_cache[self.stage_id_map[key]]
        # Fixed message: Operation keys are accepted too, not only Tensor.
        raise ValueError("Item must be Tensor or Operation")

    def __str__(self):
        return str(self.state_object)

    def __eq__(self, other):
        return _ffi_api.StateEqual(self.state_object, other.state_object)

# ===== python/tvm/ansor/measure.py (new file) =====
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
+ +"""Distributed measurement infrastructure to measure the runtime costs of tensor programs + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. + +We implement these in python to utilize python's multiprocessing and error handling +""" +from typing import List +import os +import time +import shutil +import logging +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.ir import transform + +from . import _ffi_api +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout + +LOGGER = logging.getLogger('ansor') + +# The maximum length of error message +MAX_ERROR_MSG_LEN = 512 + + +@tvm._ffi.register_object("ansor.MeasureCallback") +class MeasureCallback(Object): + """Base class for measurement callback function""" + + +@tvm._ffi.register_object("ansor.MeasureInput") +class MeasureInput(Object): + """ + Parameters + ---------- + task : SearchTask + state : State + """ + + def __init__(self, task, state): + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) + + +@tvm._ffi.register_object("ansor.BuildResult") +class BuildResult(Object): + """ + Parameters + ---------- + filename : Str + args : List[Tensor] + error_no : Int + error_msg : Str + time_cost : Float + """ + + def __init__(self, filename, args, error_no, error_msg, time_cost): + self.__init_handle_by_constructor__( + _ffi_api.BuildResult, filename if filename else "", args, error_no, + error_msg if error_msg else "", time_cost) + + +@tvm._ffi.register_object("ansor.MeasureResult") +class MeasureResult(Object): + """ + Parameters + ---------- + costs : List[Float] + error_no : Int + error_msg : Str + all_cost : Float + timestamp : Float + """ + + def __init__(self, costs, error_no, error_msg, all_cost, timestamp): + 
self.__init_handle_by_constructor__( + _ffi_api.MeasureResult, costs, error_no, + error_msg if error_msg else "", all_cost, timestamp) + + +@tvm._ffi.register_object("ansor.Builder") +class Builder(Object): + """ Base class of Builder + """ + def build(self, measure_inputs, verbose=1): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + verbost : Int + + Returns + ------- + res : List[BuildResult] + """ + return _ffi_api.BuilderBuild(self, measure_inputs, verbose) + + +@tvm._ffi.register_object("ansor.Runner") +class Runner(Object): + """ Base class of Runner + """ + def run(self, measure_inputs, build_results, verbose=1): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + build_results : List[BuildResult] + + Returns + ------- + res : List[MeasureResult] + """ + return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose) + + +@tvm._ffi.register_object("ansor.LocalBuilder") +class LocalBuilder(Builder): + """ + Parameters + ---------- + timeout : Int + n_parallel : Int + build_func : Str + """ + + def __init__(self, + timeout=15, + n_parallel=multiprocessing.cpu_count(), + build_func='default'): + self.__init_handle_by_constructor__( + _ffi_api.LocalBuilder, timeout, n_parallel, build_func) + + +@tvm._ffi.register_object("ansor.LocalRunner") +class LocalRunner(Runner): + """ + Parameters + ---------- + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + + def __init__(self, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) + + +class MeasureErrorNo(object): + """Error type for MeasureResult""" + NO_ERROR = 0 # No error + INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state + # Errors happen when compiling code on host (e.g. 
tvm.build) + COMPILE_HOST = 2 + COMPILE_DEVICE = 3 # Errors happen when compiling code on device + # (e.g. OpenCL JIT on the device) + RUNTIME_DEVICE = 4 # Errors happen when run program on device + WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output + BUILD_TIMEOUT = 6 # Timeout during compilation + RUN_TIMEOUT = 7 # Timeout during run + UNKNOWN_ERROR = 8 # Unknown error + + +def make_error_msg(): + """Get the error message from traceback""" + error_msg = str(traceback.format_exc()) + if len(error_msg) > MAX_ERROR_MSG_LEN: + error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:] + return error_msg + + +global global_build_arguments +global global_run_arguments + + +def local_build_worker(index): + """ Local builder function + """ + # We use fork to copy arguments from a global variable. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + measure_inputs, build_func, timeout, verbose = global_build_arguments + assert isinstance(build_func, str) + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + + def timed_func(): + tic = time.time() + inp = measure_inputs[index] + task = inp.task + + error_no = MeasureErrorNo.NO_ERROR + error_msg = None + args = [] + + try: + sch, args = task.compute_dag.apply_steps_from_state( + inp.state) + except Exception: + error_no = MeasureErrorNo.INSTANTIATION_ERROR + error_msg = make_error_msg() + + if error_no == 0: + dirname = tempfile.mkdtemp() + filename = os.path.join( + dirname, "tmp_func." 
+ build_func.output_format) + + try: + with transform.PassContext(): # todo(lmzheng): port the unroll pass + func = build_module.build( + sch, args, target=task.target, target_host=task.target_host) + func.export_library(filename, build_func) + except Exception: + error_no = MeasureErrorNo.COMPILE_HOST + error_msg = make_error_msg() + else: + filename = "" + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print(".", end="") + else: + print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic + + res = call_func_with_timeout(timeout, timed_func) + if isinstance(res, TimeoutError): + if verbose >= 1: + print(".T", end="") # Build timeout + res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + + return res + + +@tvm._ffi.register_func("ansor.local_builder.build") +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, + build_func: str, verbose: int): + """ Local builder build function + """ + # We use fork to copy arguments from a global variable. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + global global_build_arguments + global_build_arguments = (inputs, build_func, timeout, verbose) + + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(local_build_worker, range(len(inputs))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(BuildResult(*res)) + + return results + +@tvm._ffi.register_func("ansor.local_runner.run") +def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], + timeout: float, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: int): + """ ... 
+ """ + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + + def timed_func(inp, build_res): + tic = time.time() + error_no = 0 + error_msg = None + try: + func = module.load_module(build_res.filename) + ctx = ndarray.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + time.sleep(cooldown_interval) + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + measure_results = [] + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + for inp, build_res in zip(inputs, build_results): + if build_res.error_no != 0: + res = (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() + else: + res = call_func_with_timeout( + timeout, timed_func, args=(inp, build_res)) + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, \ + build_res.time_cost + timeout, time.time() + measure_results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return measure_results diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py new file mode 100644 index 000000000000..1bd9d8cf64e6 --- /dev/null +++ 
b/python/tvm/ansor/serialization.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Serialization and other I/O support for tuning logs (measurement records)""" + +import numpy as np + +import tvm._ffi +from tvm.runtime import Object +from .measure import MeasureCallback, MeasureErrorNo +from .loop_state import State +from . 
import _ffi_api


@tvm._ffi.register_object("ansor.LogToFile")
class LogToFile(MeasureCallback):
    """A measurement callback that appends measurement records to a file.

    Parameters
    ----------
    filename : str
        Path of the json log file to write to.
    """

    def __init__(self, filename="ansor_tuning.json"):
        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)


@tvm._ffi.register_object("ansor.LogReader")
class LogReader(Object):
    """Reader of the json log file.

    Parameters
    ----------
    filename : str
        Path of the json log file to read.
    """

    def __init__(self, filename="ansor_tuning.json"):
        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)

    def read_lines(self, max_size=-1, skip_size=0):
        """Read measurement records from the file.

        Parameters
        ----------
        max_size : int
            Maximum number of lines to read; a negative value means "all".
        skip_size : int
            Number of leading lines to skip.

        Returns
        -------
        inputs, results
            Two parallel arrays of MeasureInput and MeasureResult.
        """
        inputs, results = _ffi_api.LogReaderReadLines(
            self, max_size, skip_size)
        return inputs, results

    def __iter__(self):
        """Iterate over (MeasureInput, MeasureResult) pairs until EOF."""
        while True:
            ret = _ffi_api.LogReaderReadNext(self)
            if not ret:  # None or an empty container marks end of file
                break
            yield ret[0], ret[1]  # (input, result)


def load_from_file(filename: str):
    """Load measurement records from a file.

    Returns
    -------
    A zip of (MeasureInput, MeasureResult) iterators.
    """
    return zip(*LogReader(filename).read_lines())


def write_measure_records_to_file(filename, inputs, results):
    """Write (append) measure records to a file."""
    _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)


def get_states_from_measure_inputs(inputs, task):
    """Get states from measure inputs."""
    state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task)
    return [State(s, task.compute_dag) for s in state_objects]


def best_measure_pair_in_file(filename, workload_key=None, target=None):
    """Return the best measurement pair from a log file.

    Parameters
    ----------
    filename : str
        Path of the json log file.
    workload_key : Optional[str]
        If given, only consider records with this workload key.
    target : Optional[tvm.target.Target]
        If given, only consider records measured for this target.
        NOTE: this must be a Target object (its ``target_name`` attribute is
        compared), not a plain string.

    Returns
    -------
    inp : MeasureInput
    res : MeasureResult
    """
    best_cost = 1e30
    best_inp = None
    best_res = None

    for inp, res in LogReader(filename):
        if res.error_no != MeasureErrorNo.NO_ERROR:
            continue
        if workload_key and inp.task.workload_key != workload_key:
            continue
        if target and inp.task.target.target_name != target.target_name:
            continue

        cost = np.mean([v.value for v in res.costs])
        if cost < best_cost:
            best_cost = cost
            best_inp = inp
            best_res = res

    return best_inp, best_res


# ===== python/tvm/ansor/utils.py =====
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Common utilities for ansor"""

import multiprocessing
import multiprocessing.pool
import queue
import signal
import threading
import os

import numpy as np
try:
    import psutil
except ImportError:
    psutil = None

from tvm import rpc
from tvm.tir import expr
from tvm.tir.transform import Simplify
from tvm.ir.transform import Sequential


def get_func_name(func):
    """Get name of a function.

    Parameters
    ----------
    func: Function
        The function

    Returns
    -------
    name: str
        The name
    """
    return func.func_name if hasattr(func, 'func_name') else func.__name__


def get_const_int(exp):
    """Verifies expr is an integer and gets the constant value.

    Parameters
    ----------
    exp : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The output.
    """
    if isinstance(exp, int):
        return exp
    if not isinstance(exp, expr.IntImm):
        # Try to simplify the expression into a constant first.
        opt = Sequential([Simplify()])
        exp = opt(exp)
    if not isinstance(exp, expr.IntImm):
        raise ValueError("Expect value to be constant int")
    return exp.value


def get_const_tuple(in_tuple):
    """Verifies the input tuple is IntImm, returns a tuple of int.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
    """
    return tuple(get_const_int(x) for x in in_tuple)


def to_str_round(x, decimal=6):
    """Convert an object to str and round all float numbers inside it."""
    if isinstance(x, str):
        return x
    if isinstance(x, (list, tuple, np.ndarray)):
        return "[" + ", ".join(to_str_round(y, decimal=decimal)
                               for y in x) + "]"
    if isinstance(x, dict):
        # NOTE(review): eval() only re-parses strings produced by this very
        # function, so this is not an untrusted-input risk, but it is
        # fragile; consider formatting the dict entries directly.
        return str({k: eval(to_str_round(v)) for k, v in x.items()})
    if isinstance(x, int):
        return str(x)
    if isinstance(x, (np.float32, np.float64, float)):
        format_str = "%%.%df" % decimal
        return format_str % x
    raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x)))


def array_mean(arr):
    """Mean function for a tvm Array of floating point immediates."""
    return sum(x.value for x in arr) / len(arr)


class NoDaemonProcess(multiprocessing.Process):
    # Always report non-daemon so children may spawn their own processes.
    @property
    def daemon(self):
        return False

    @daemon.setter
    def daemon(self, value):
        pass  # intentionally ignore attempts to set the daemon flag


class NoDaemonContext(type(multiprocessing.get_context())):
    Process = NoDaemonProcess


class NoDaemonPool(multiprocessing.pool.Pool):
    """A no daemon pool version of multiprocessing.Pool.
    This allows us to start new processes inside the worker function"""

    def __init__(self, *args, **kwargs):
        kwargs['context'] = NoDaemonContext()
        super().__init__(*args, **kwargs)


def kill_child_processes(parent_pid, sig=signal.SIGTERM):
    """Kill all child processes of `parent_pid` recursively.

    Silently does nothing if psutil is unavailable or the parent is gone.
    """
    if psutil is None:
        # psutil is an optional dependency (see the guarded import above);
        # degrade to a no-op instead of raising AttributeError.
        return
    try:
        parent = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return
    for process in parent.children(recursive=True):
        try:
            process.send_signal(sig)
        except psutil.NoSuchProcess:
            # BUGFIX: the original returned here, which left the remaining
            # children running if one had already exited; keep going instead.
            continue


def call_func_with_timeout(timeout, func, args=(), kwargs=None):
    """Call a function in a child process with a timeout.

    Returns the function's result, or a TimeoutError instance (not raised)
    if the call did not finish within `timeout` seconds.
    """
    def func_wrapper(que):
        if kwargs:
            que.put(func(*args, **kwargs))
        else:
            que.put(func(*args))

    que = multiprocessing.Queue(2)
    process = multiprocessing.Process(target=func_wrapper, args=(que,))
    process.start()
    process.join(timeout)

    try:
        res = que.get(block=False)
    except queue.Empty:
        res = TimeoutError()

    # clean up the queue and the (possibly still running) process tree
    kill_child_processes(process.pid)
    process.terminate()
    process.join()
    que.close()
    que.join_thread()
    del process
    del que

    return res


# ===== python/tvm/ansor/workload_registry.py =====
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Workload registration and serialization.

We use a json string to represent a workload (a compute dag).
The format of the string is `[func_name, [args...]]`.
The dag should be the return value of this `func_name(*args)`.

Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
and matching them efficiently is not easy. Therefore, we use the above string to encode a compute
dag.
These strings are efficient for serialization/matching and won't be too long.
When we need the dag, we decode the string and call the function, which will return the dag.
"""

from typing import List, Tuple, Callable, Union
# BUGFIX: `collections.Hashable` was deprecated since Python 3.3 and removed
# in Python 3.10; import from collections.abc instead.
from collections.abc import Hashable
import pickle
import json
import hashlib

import tvm._ffi
from ..te import Tensor, PlaceholderOp, ComputeOp, placeholder
from .utils import get_const_tuple
from .compute_dag import ComputeDAG

# Global registry mapping function name (or dag hash) -> generation function
# (or directly the list of buffers; see register_workload_bufs).
WORKLOAD_FUNC_REGISTRY = {}


def register_workload_func(func: Callable):
    """Register a workload generation function.

    The input function should take hashable and jsonable arguments
    (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of
    tvm.tensor.Tensor.

    Examples
    --------
    @register_workload_func
    def matmul(N, M, K):
        A = te.placeholder((N, K), name='A')
        B = te.placeholder((K, M), name='B')
        k = te.reduce_axis((0, K), name='k')
        C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
        return [A, B, C]
    """
    func_name = func.__name__
    if func_name in WORKLOAD_FUNC_REGISTRY:
        raise RuntimeError('%s has been registered already' % func_name)
    WORKLOAD_FUNC_REGISTRY[func_name] = func
    return func


def compute_dag_hash(dag: ComputeDAG):
    """Get a stable hash value for a ComputeDAG."""
    # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG
    str_key = ''
    for op in dag.ops:
        t = op.output(0)
        if isinstance(op, PlaceholderOp):
            str_key += 'placeholder,'
            str_key += str(get_const_tuple(t.shape)) + ','
            str_key += t.dtype + ';'
        elif isinstance(op, ComputeOp):
            str_key += str(t.op.body) + ','
            str_key += str(get_const_tuple(t.shape)) + ','
            str_key += t.dtype + ';'
        else:
            # BUGFIX: `"..." + op` raised TypeError because op is not a str.
            raise ValueError("Invalid op: " + str(op))

    str_key = str_key.encode(encoding='utf-8')
    return hashlib.md5(str_key).hexdigest()


def register_workload_bufs(bufs: List[Tensor]) -> str:
    """Directly register buffers of a workload and return the workload_key.

    The buffers can be looked up with workload_key_to_tensors by the
    workload_key.
    """
    dag = ComputeDAG(bufs)
    key = compute_dag_hash(dag)
    WORKLOAD_FUNC_REGISTRY[key] = bufs
    return json.dumps((key,))


def list_to_tuple(x: List) -> Tuple:
    """Convert a list to a tuple recursively."""
    assert isinstance(x, list)
    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)


def serialize_args(args: Tuple) -> Tuple:
    """Serialize arguments of a function to a hashable and jsonable tuple.

    Currently this is mainly used for tvm.tensor.Tensor.
    """
    ret = []
    for t in args:
        if isinstance(t, Tensor):
            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
        elif isinstance(t, list):
            t = list_to_tuple(t)

        assert isinstance(t, Hashable), str(t) + " is not hashable"
        ret.append(t)

    return tuple(ret)


def deserialize_args(args: Tuple) -> List:
    """The inverse function of :code:`serialize_args`."""
    ret = []
    for t in args:
        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
            ret.append(placeholder(shape=t[1], dtype=t[2]))
        else:
            ret.append(t)
    return ret


@tvm._ffi.register_func("ansor.workload_key_to_tensors")
def workload_key_to_tensors(workload_key: str) -> List[Tensor]:
    """Decode a workload key to the input/output tensors."""
    workload = json.loads(workload_key)
    name = workload[0]
    lookup = WORKLOAD_FUNC_REGISTRY[name]

    if callable(lookup):
        args = deserialize_args(workload[1:])
        return lookup(*args)
    return lookup  # registered directly as a list of buffers


@tvm._ffi.register_func("ansor.workload_key_to_dag")
def workload_key_to_dag(workload_key: str) -> ComputeDAG:
    """Decode a workload key to a compute dag."""
    tensors = workload_key_to_tensors(workload_key)
    return ComputeDAG(tensors)


def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str:
    """Make a workload key from a function and its arguments."""
    args = serialize_args(args)

    if callable(func):
        func_name = func.__name__
    elif isinstance(func, str):
        func_name = func
    else:
        raise ValueError("Invalid function: " + str(func))

    # BUGFIX: the message referred to a nonexistent
    # `register_auto_scheduler_workload_func`; the actual registration
    # function in this module is `register_workload_func`.
    assert func_name in WORKLOAD_FUNC_REGISTRY, \
        "%s is not registered. Please register it with register_workload_func" % func_name

    return json.dumps((func_name,) + args)


def make_workload_key_bufs(bufs: List[Tensor]) -> str:
    """Make a workload key from a list of buffers."""
    dag = ComputeDAG(bufs)
    key = compute_dag_hash(dag)
    return json.dumps((key,))


def dump_workload_func_registry(filename: str):
    """Dump the workload function registry to a pickle binary file."""
    global WORKLOAD_FUNC_REGISTRY

    with open(filename, 'wb') as fout:
        pickle.dump(WORKLOAD_FUNC_REGISTRY, fout)


def load_workload_func_registry(filename: str):
    """Load the workload function registry from a pickle binary file.

    NOTE(review): pickle.load executes arbitrary code; only load registry
    files from trusted sources.
    """
    global WORKLOAD_FUNC_REGISTRY

    with open(filename, 'rb') as fin:
        WORKLOAD_FUNC_REGISTRY = pickle.load(fin)


# ===== patch hunk: python/tvm/rpc/server.py (pass --port-end through to the
# spawned RPC server process) =====
# @@ -348,7 +348,8 @@ def __init__(self,
#          cmd = [sys.executable,
#                 "-m", "tvm.exec.rpc_server",
#                 "--host=%s" % host,
# -               "--port=%s" % port]
# +               "--port=%s" % port,
# +               "--port-end=%s" % port_end]
#          if tracker_addr:
#              assert key
#              cmd += ["--tracker=%s:%d" % tracker_addr,

# ===== src/ansor/auto_schedule.cc =====
# /*
#  * Licensed to the Apache Software Foundation (ASF) under one
#  * or more contributor license agreements.  See the NOTICE file
#  * distributed with this work for additional information
#  * regarding copyright ownership.  The ASF licenses this file
#  * to you under the Apache License, Version 2.0 (the
#  * "License"); you may not use this file except in compliance
#  * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/auto_schedule.cc + * \brief The user interface of the auto-scheduler + */ + +#include "auto_schedule.h" +#include +#include +#include + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(TuneOptionNode); + +TuneOption::TuneOption(int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array measure_callbacks, + Array pre_search_callbacks) { + auto node = make_object(); + node->n_trials = n_trials; + node->early_stopping = early_stopping; + node->num_measure_per_iter = num_measure_per_iter; + node->verbose = verbose; + node->builder = std::move(builder); + node->runner = std::move(runner); + node->measure_callbacks = std::move(measure_callbacks); + node->pre_search_callbacks = std::move(pre_search_callbacks); + data_ = std::move(node); +} + +std::pair > AutoSchedule(SearchTask task, + SearchPolicy search_policy, TuneOption tune_option) { + // Search for the best schedule + ProgramMeasurer measurer = + ProgramMeasurer(tune_option->builder, tune_option->runner, + tune_option->measure_callbacks, + tune_option->verbose); + + State state = search_policy->Search( + task, tune_option->n_trials, tune_option->early_stopping, + tune_option->num_measure_per_iter, tune_option->verbose, measurer, + tune_option->pre_search_callbacks); + + return task->compute_dag.ApplySteps(state->transform_steps); +} + +std::pair > AutoSchedule( + std::string workload_key, Target target, Target target_host, + SearchPolicy search_policy, HardwareParams hardware_params, + TuneOption 
tune_option) { + ComputeDAG dag = ComputeDAG(workload_key); + SearchTask task = SearchTask( + std::move(dag), std::move(workload_key), std::move(target), + std::move(target_host), std::move(hardware_params)); + return AutoSchedule(std::move(task), std::move(search_policy), + std::move(tune_option)); +} + +TVM_REGISTER_GLOBAL("ansor.TuneOption") +.set_body_typed([](int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array measure_callbacks, + Array pre_search_callbacks) { + return TuneOption(n_trials, early_stopping, num_measure_per_iter, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); +}); + +TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") +.set_body_typed([](SearchTask task, SearchPolicy search_policy, + TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); + + return Array{sch, return_tensors}; +}); + +TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") +.set_body_typed([](std::string workload_key, Target target, + Target target_host, SearchPolicy search_policy, + HardwareParams hardware_params, TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = + AutoSchedule(workload_key, target, target_host, search_policy, + hardware_params, tune_option); + + return Array{sch, return_tensors}; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h new file mode 100644 index 000000000000..7ffd2c4d3a70 --- /dev/null +++ b/src/ansor/auto_schedule.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/auto_schedule.h + * \brief The user interface of the auto-scheduler + */ + +#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ +#define TVM_ANSOR_AUTO_SCHEDULE_H_ + +#include +#include +#include "measure.h" +#include "search_policy/search_policy.h" + +namespace tvm { +namespace ansor { + +/*! \brief Tuning and measurement options */ +class TuneOptionNode : public Object { + public: + int n_trials; // Number of total measurement trials + int early_stopping; // Stops early the tuning if no improvement after n measurements + int num_measure_per_iter; // The number of programs to be measured at each iteration + int verbose; // Verbosity level. 0 means silent. 
+ Builder builder; // Builder which builds the program + Runner runner; // Runner which runs the program and measure time costs + Array measure_callbacks; // MeasureCallback functions + Array pre_search_callbacks; // SearchCallback functions + // run before search + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("n_trials", &n_trials); + v->Visit("early_stopping", &early_stopping); + v->Visit("num_measure_per_iter", &num_measure_per_iter); + v->Visit("verbose", &verbose); + v->Visit("builder", &builder); + v->Visit("runner", &runner); + v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("pre_search_callbacks", &pre_search_callbacks); + } + + static constexpr const char* _type_key = "ansor.TuneOption"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); +}; + +/*! + * \brief Managed reference to TuneOptionNode. + * \sa TuneOptionNode + */ +class TuneOption : public ObjectRef { + public: + TuneOption(int n_trials, int early_stopping, int num_measure_per_iter, + int verbose, Builder builder, Runner runner, + Array measure_callbacks, + Array pre_search_callbacks); + + TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); +}; + +/*! \brief Auto schedule for a compute declaration */ +std::pair > AutoSchedule( + SearchTask task, SearchPolicy search_policy, TuneOption tune_option); + +/*! \brief Auto schedule for a compute declaration */ +std::pair > AutoSchedule( + std::string workload_key, Target target, Target target_host, + SearchPolicy search_policy, HardwareParams hardware_params, + TuneOption tune_option); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc new file mode 100644 index 000000000000..7638f98e65ea --- /dev/null +++ b/src/ansor/compute_dag.cc @@ -0,0 +1,951 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/compute_dag.cc + * \brief Compute declaration graph and its related analysis tools + */ + +#include "compute_dag.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "transform_step.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +TVM_REGISTER_NODE_TYPE(ComputeDAGNode); + +template +using OperationMap = AccessAnalyzerNode::OperationMap; + +using OperationSet = std::unordered_set; + +// Topo-sort ops from tensors according to their read-write relations. 
+// Results are stored in ops +void TopoSortOps(const Array& tensors, + std::vector* ops) { + std::unordered_map degree; + std::unordered_map > edge_set; + std::unordered_map priority; + std::unordered_set visited; + + // traverse to build edge_set and count degree + std::vector stack; + stack.reserve(tensors.size()); + for (const auto& x : tensors) { + stack.push_back(x->op.operator->()); + } + + int ct = 0; + while (!stack.empty()) { + const te::OperationNode* op = stack.back(); + stack.pop_back(); + if (visited.count(op)) { + continue; + } + + priority[op] = ct; + ct++; + visited.insert(op); + + if (op->IsInstance()) { + degree[op] = 0; + } else if (auto cop = GetRef(op).as()) { + const Array& input_tensors = cop->InputTensors(); + degree[op] = input_tensors.size(); + for (const auto& ten : input_tensors) { + edge_set[ten->op.operator->()].push_back(op); + stack.push_back(ten->op.operator->()); + } + } else { + LOG(FATAL) << "Unsupported op " << GetRef(op); + } + } + + // topo sort + ops->clear(); + + using Item = std::pair; + auto cmp = [](const Item& left, const Item& right) { + return left.second < right.second; + }; + std::priority_queue, decltype(cmp)> queue(cmp); + for (const auto& iter : degree) { + if (iter.second == 0) { + queue.push(Item(iter.first, priority[iter.first])); + } + } + + ops->reserve(degree.size()); + while (!queue.empty()) { + Item item = queue.top(); + queue.pop(); + ops->push_back(GetRef(item.first)); + for (const auto& dst : edge_set[item.first]) { + degree[dst] -= 1; + if (degree[dst] == 0) { + queue.push(Item(dst, priority[dst])); + } + } + } +} + +// Extract all tensor accesses in an expr +class TensorAccessExtractor : public StmtExprVisitor { + public: + void Extract(PrimExpr expr) { + this->VisitExpr(expr); + } + + void VisitExpr_(const CallNode* op) final { + if (op->name == tir::intrinsic::tvm_if_then_else) { + has_branch = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const ProducerLoadNode* op) final { 
+ buf_accesses[Downcast(op->producer)->op].emplace_back( + op->indices.begin(), op->indices.end()); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + has_branch = true; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const SelectNode* op) final { + has_branch = true; + StmtExprVisitor::VisitExpr_(op); + } + + OperationMap > > buf_accesses; + bool has_branch{false}; +}; + +// Returns whether the expr equals to the var with a const shift +bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { + if (auto pv = expr.as()) { + return pv == var.get(); + } else if (auto padd = expr.as()) { + return ((padd->a.get() == var.get() && padd->b->IsInstance()) || + (padd->b.get() == var.get() && padd->a->IsInstance())); + } else if (auto psub = expr.as()) { + return ((psub->a.get() == var.get() && psub->b->IsInstance()) || + (psub->b.get() == var.get() && psub->a->IsInstance())); + } else { + return false; + } +} + +// Return whether the access is injective +bool IsInjective(const te::Operation& op, const std::vector& index, + bool* axis_missing, bool* axis_duplicated, bool* same_order) { + auto cop = op.as(); + if (cop == nullptr) { return false; } + + std::vector index_to_var_idx; + std::vector var_idx_ct(cop->axis.size(), 0); + + for (const auto& expr : index) { + if (!is_const(expr)) { + bool found = false; + for (size_t i = 0; i < cop->axis.size(); ++i) { + if (IsConstShiftEqual(cop->axis[i]->var, expr)) { + index_to_var_idx.push_back(i); + var_idx_ct[i]++; + found = true; + break; + } + } + if (!found) { + return false; + } + } + } + + *axis_missing = false; // Some axes are missing + *axis_duplicated = false; // Some axes appear more than once + *same_order = true; // The axis order is the same as op->axis + for (int ct : var_idx_ct) { + if (ct == 0) { + *axis_missing = true; + } else if (ct > 1) { + *axis_duplicated = true; + } + } + for (size_t i = 1; i < index_to_var_idx.size(); ++i) { + if 
(index_to_var_idx[i] < index_to_var_idx[i - 1]) { + *same_order = false; + break; + } + } + + return true; +} + +// Gather all VarNodes in an expr +static void GatherVars(const PrimExpr& expr, + std::unordered_set* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef &node) { + if (const VarNode* op = node.as()) { + vars->insert(op); + } + }); +} + +// Check whether an expr has expensive operations (e.g. exp) +static bool HasExpensiveOp(const PrimExpr& expr) { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef &node) { + if (const CallNode* op = node.as()) { + if (op->call_type == CallNode::CallType::PureIntrinsic && + op->name == "exp") { + found = true; + } + } + }); + return found; +} + +AccessAnalyzer::AccessAnalyzer(const Array& tensors) { + auto node = make_object(); + OperationMap has_branch; + + // get all ops + TopoSortOps(tensors, &node->ops_topo_order); + + // build read & write access map + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->read_from[op] = + OperationMap > >(); + } else if (auto cop = op.as()) { + TensorAccessExtractor extractor; + for (const auto& exp : cop->body) { + extractor.Extract(exp); + } + + for (const auto& iter : extractor.buf_accesses) { + std::vector >& accesses = + node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), + iter.second.end()); + } + + node->read_from[op] = std::move(extractor.buf_accesses); + has_branch[op] = extractor.has_branch; + } else { + LOG(FATAL) << "Invalid op: " << op; + } + } + + // do some static analysis + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->is_injective[op] = true; + node->needs_multi_level_tiling[op] = false; + node->is_strict_inlineable[op] = false; + node->is_output[op] = false; + } else if (auto pop = op.as()) { + // check whether is element-wise and strict-inlineable + // (see definition in compute_dag.h) + bool is_injective = true; + bool is_strict_inlineable = 
true; + + bool axis_missing, axis_duplicated, same_order; + for (const auto& pair : node->read_from[op]) { + const std::vector >& access = pair.second; + for (const auto& index : access) { + if (!ansor::IsInjective(op, index, &axis_missing, &axis_duplicated, + &same_order)) { + is_injective = false; + is_strict_inlineable = false; + break; + } + if (!same_order || axis_duplicated) { + // do not strictly inline transpose + is_strict_inlineable = false; + } + } + if (!is_injective) { break; } + } + if (has_branch[op]) { + is_strict_inlineable = false; + } + + // don't strictly inline expensive op (e.g. exp) + bool has_expensive_op = false; + for (const auto& expr : pop->body) { + has_expensive_op |= HasExpensiveOp(expr); + } + + node->is_injective[op] = is_injective; + node->is_strict_inlineable[op] = is_strict_inlineable && + !has_expensive_op; + + // check whether the op needs multi-level tiling + // (see definition in compute_dag.h) + bool needs_multi_level_tiling = false; + int n_missing = 0; + + for (const auto& pair : node->read_from[op]) { + const std::vector > &access = pair.second; + std::unordered_set vars; + for (const std::vector &indices : access) { + for (const PrimExpr& expr : indices) { + GatherVars(expr, &vars); + } + } + bool missing = false; + for (const auto& axis : pop->axis) { + if (GetIntImm(axis->dom->extent) > 1 && + vars.count(axis->var.get()) == 0) { + missing = true; + } + } + if (missing) { + n_missing++; + } + + if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + needs_multi_level_tiling = true; + break; + } + } + + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; + + // check whether is output + node->is_output[op] = node->read_by[op].empty(); + } else { + LOG(FATAL) << "Invalid op" << op; + } + } + + data_ = std::move(node); +} + +bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const { + return operator->()->needs_multi_level_tiling.at(op); +} + +bool AccessAnalyzer::IsOutput(const 
te::Operation& op) const { + return operator->()->is_output.at(op); +} + +bool AccessAnalyzer::IsInjective(const te::Operation& op) const { + return operator->()->is_injective.at(op); +} + +bool AccessAnalyzer::IsStrictInlineable(const te::Operation &op) const { + return operator->()->is_strict_inlineable.at(op); +} + +void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, + OperationSet* producers) const { + producers->clear(); + for (const auto& iter : operator->()->read_from.at(op)) { + producers->insert(iter.first); + } +} + +void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, + OperationSet* consumers) const { + OperationSet inlined_ops; + + for (const auto& stage : state->stages) { + if (stage->compute_at == kInlined) { + inlined_ops.insert(stage->op); + } + } + std::function collect; + + collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { + for (const auto& iter : operator->()->read_by.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + consumers->insert(iter.first); + } + } + }; + + consumers->clear(); + collect(op); +} + +// Return whether two int arrays are elementwise-equal +bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + +bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const { + te::Operation cur_op = op; + while (cur_op != target_op) { + const AccessAnalyzerNode::OperationMap > >& map = + operator->()->read_by.at(cur_op); + + if (map.size() != 1) { + return false; + } + te::Operation next_op = map.begin()->first; + + // Check condition 1: has the same output size + auto p_cur = cur_op.as(); + auto 
p_next = next_op.as(); + if (p_cur == nullptr || p_next == nullptr) { + return false; + } + + Array output_shape = p_cur->output_shape(0); + for (int i = 1; i < p_cur->num_outputs(); ++i) { + if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { + return false; + } + } + for (int i = 0; i < p_next->num_outputs(); ++i) { + if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { + return false; + } + } + + // Check condition 2: read is elementwise + const std::vector > reads = map.begin()->second; + bool is_injective, axis_missing, axis_duplicated, same_order; + for (const auto& read : reads) { + is_injective = ::tvm::ansor::IsInjective( + next_op, read, &axis_missing, &axis_duplicated, &same_order); + if (!is_injective || axis_missing || axis_duplicated || !same_order) { + return false; + } + } + + cur_op = std::move(next_op); + } + return true; +} + +// Estimate number of float operations in an expression +class FlopEstimator: public ExprFunctor { + public: + double EstimateFlop(const Array& ops) { + double ret = 0; + for (const auto& op : ops) { + if (auto pop = op.as()) { + double num_element = AxisLengthProd(pop->axis); + if (num_element == -1) { + fail = true; + break; + } + double op_per_element = 0; + for (const auto& x : pop->body) { + op_per_element += VisitExpr(x); + } + ret += num_element * op_per_element; + } else if (op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op type " << op; + } + } + + return fail ? 
-1 : ret; + } + + double VisitExpr_(const ReduceNode* op) final { + uint64_t num_iter = 1; + for (const auto& x : op->axis) { + if (auto imm = x->dom->extent.as()) { + num_iter *= imm->value; + } else { + fail = true; + num_iter = -1; + } + } + double body_flop = 0; + for (size_t i = 0; i < op->combiner->result.size(); ++i) { + body_flop += VisitExpr(op->combiner->result[i]); + body_flop += VisitExpr(op->source[i]); + } + return num_iter * body_flop; + } + + double VisitExpr_(const FloatImmNode* op) final { return 0.0; } + double VisitExpr_(const IntImmNode* op) final { return 0.0; } + double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } + + double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + double VisitExpr_(const VarNode* op) final { return 0.0; } + + double VisitExpr_(const SelectNode* op) final { + return VisitExpr(op->condition) + std::max(VisitExpr(op->true_value), + VisitExpr(op->false_value)); + } + +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); \ + } +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a); \ + } + + VisitBinary(AddNode); VisitBinary(SubNode); VisitBinary(MulNode) + VisitBinary(DivNode); VisitBinary(ModNode); VisitBinary(FloorDivNode) + VisitBinary(FloorModNode); VisitBinary(MaxNode); VisitBinary(MinNode); + VisitBinary(EQNode); VisitBinary(NENode); VisitBinary(LTNode); + VisitBinary(LENode); VisitBinary(GTNode); VisitBinary(GENode); + VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); + + double VisitExpr_(const CallNode* op) final { + double ret = 0.0; + for (const auto&x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + + double VisitExprDefault_(const Object* op) final { + fail = true; + return -1.0; + } + + bool fail{false}; +}; + +State ComputeDAG::GetInitState() const { + return Downcast(operator->()->init_state); +} + 
+ComputeDAG::ComputeDAG(Array tensors) { + auto node = make_object(); + FlopEstimator estimator; + node->tensors = std::move(tensors); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); +} + +ComputeDAG::ComputeDAG(const std::string& workload_key) { + Array tens; + // Call python function to decode the workload_key and get the I/O tensors + if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { + tens = (*f)(workload_key); + } else { + LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; + } + + auto node = make_object(); + FlopEstimator estimator; + node->tensors = std::move(tens); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); +} + +std::string BaseName(const std::string& str) { + return str.substr(0, str.rfind("_")); +} + +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + +std::pair > ComputeDAG::ApplySteps( + const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level) const { + std::vector stages; + StageToAxesMap stage_to_axes; + return ReplaySteps(transform_steps, &stages, &stage_to_axes); +} + +std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { + std::vector stages; + StageToAxesMap stage_to_axes; + Array ops; + for 
(const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + te::Schedule schedule = te::create_schedule({ops.back()}); + + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages.push_back(stage); + UpdateStageAxis(stage, &stage_to_axes); + } + + std::stringstream ss; + + for (const auto& stage : stages) { + if (stage->op->IsInstance()) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + ss << stage->leaf_iter_vars[i]->var->name_hint; + if (i != stage->leaf_iter_vars.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << stage->op->name << ".op.axis)" + << " + " << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; + } + } + + for (const auto& step : transform_steps) { + ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, + transform_steps); + } + + return ss.str(); +} + +State ComputeDAG::ReplayAndInferBound( + const std::vector& transform_steps) const { + State ret_state = GetInitState(); + StateNode* pstate = ret_state.CopyOnWrite(); + pstate->transform_steps = transform_steps; + ret_state.DoSteps(transform_steps, *this); + + InferBoundCommon(pstate); + + return ret_state; +} + +State ComputeDAG::InferBound(const State& state) const { + State ret_state = state; + StateNode* pstate = ret_state.CopyOnWrite(); + + InferBoundCommon(pstate); + + return ret_state; +} + +void ComputeDAG::InferBound(std::vector* states) const { + std::vector out_states(states->size(), State()); + + auto worker_func = [&states, &out_states, this](int idx) { + try { + out_states[idx] = this->InferBound((*states)[idx]); + } catch (dmlc::Error &e) { + LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] + << "\n" << e.what() << std::endl; + } + }; + + // Lower states in parallel + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(states->size()); + for (size_t i = 0; i < states->size(); ++i) { + pool.Enqueue(worker_func, i); + } + 
pool.WaitBatch(); + + *states = std::move(out_states); +} + +void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, + ComputeDAG *task_dag) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array old_tensors; + + std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, + &stage_to_axes); + + Array new_tensors; + for (auto stage : sch->stages) { + if (stage->op->IsInstance() || + stage->is_output) { + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + new_tensors.push_back(stage->op.output(i)); + } + } + } + + *task_dag = ComputeDAG(new_tensors); +} + + +void ComputeDAG::InferBoundCommon(StateNode* pstate) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array tensors; + Map bounds; + + std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, + &stage_to_axes); + sch = sch.normalize(); + bounds = te::InferBound(sch); + + for (size_t i = 0; i < pstate->stages.size(); ++i) { + const Stage& stage = pstate->stages[i]; + + if (stage->compute_at == kInlined) { + continue; + } + + std::vector new_iters; + new_iters.reserve(stage->iters.size()); + for (size_t j = 0; j < stage->iters.size(); ++j) { + const Iterator& iter = stage->iters[j]; + const IterVar& axis = stage_to_axes.at(stages[i])[j]; + + auto find_res = bounds.find(axis); + if (find_res != bounds.end()) { + new_iters.push_back(Iterator(iter->name, (*find_res).second, + iter->iter_type, iter->annotation, + &iter->ori_iters, iter->attr)); + } else { + LOG(FATAL) << "Infer bound fails"; + } + } + + pstate->stages[i] = Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs); + } +} + +std::pair > ComputeDAG::ReplaySteps( + const std::vector &transform_steps, + std::vector *stages, + StageToAxesMap *stage_to_axes) const { + std::vector ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + + te::Schedule schedule = 
te::create_schedule({ops.back()}); + + // init axes + stages->reserve(operator->()->ops.size()); + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages->push_back(stage); + UpdateStageAxis(stage, stage_to_axes); + } + + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; + + // replay history + for (const auto& step : transform_steps) { + if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { + break; + } + + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } + + return std::make_pair(schedule, operator->()->tensors); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + std::stringstream ss; + + for (const auto& op : node->ops) { + if (op->IsInstance()) { + ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + } else if (auto pop = op.as()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->name << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as()) { + CHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { 
+ ss << " min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + const auto& select = combiner.as(); + ss << " select(" << select->condition << ", " << select->true_value + << ", " << select->false_value << ")= " << '(' + << preduce->source[0] << ',' << preduce->source[1] << ")\n"; + } else { + LOG(FATAL) << "Unsupported reduction operator" << combiner; + } + } else { + ss << " = " << pop->body[k] << "\n"; + } + } + } else { + LOG(FATAL) << "Invalid op"; + } + } + + p->stream << ss.str(); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + for (const auto& op : node->ops_topo_order) { + p->stream << op << std::endl; + p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t"; + p->stream << "needs_multi_level_tiling:\t" + << node->needs_multi_level_tiling.at(op) << std::endl; + p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) + << "\t"; + p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; + p->stream << "Read from:\t"; + for (const auto& pair : node->read_from.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->name << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "Read by:\t"; + for (const auto& pair : node->read_by.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->name << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "==================================================\n"; + } + + AccessAnalyzer ana = GetRef(node); + + p->stream << "ElementwiseMatch: \n"; + for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { + for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { + if (i == j) { continue; } + if (ana.ElementWiseMatch(node->ops_topo_order[i], + node->ops_topo_order[j])) { + p->stream << node->ops_topo_order[i]->name << " -> " + << node->ops_topo_order[j]->name << "\n"; + } + } + } +}); 
+ +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { + return ComputeDAG(tensors); +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") +.set_body_method(&ComputeDAG::GetInitState); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +.set_body([](TVMArgs args, TVMRetValue *ret) { + ComputeDAG dag = args[0]; + State state = args[1]; + LayoutRewriteLevel layout_rewrite_level = kNoRewrite; + if (args.size() >= 3) { + layout_rewrite_level = LayoutRewriteLevel(static_cast((args[2]))); + } + + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level); + *ret = Array{sch, return_tensors}; +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.ReplayAndInferBound(state->transform_steps); +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h new file mode 100644 index 000000000000..2f1330d612dd --- /dev/null +++ b/src/ansor/compute_dag.h @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/compute_dag.h + * \brief Compute declaration graph and its related analysis tools + */ + +#ifndef TVM_ANSOR_COMPUTE_DAG_H_ +#define TVM_ANSOR_COMPUTE_DAG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +class StateNode; class State; class Step; + +/*! \brief Read/Write access static analysis result */ +class AccessAnalyzerNode : public Object { + public: + template + using OperationMap = std::unordered_map; + + OperationMap > > > read_from; + OperationMap > > > read_by; + OperationMap is_injective; + OperationMap is_strict_inlineable; + OperationMap needs_multi_level_tiling; + OperationMap is_output; + std::vector ops_topo_order; + + static constexpr const char* _type_key = "ansor.AccessAnalyzer"; + TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); +}; + +/*! + * \brief Managed reference to AccessAnalyzerNode. + * \sa AccessAnalyzerNode + */ +class AccessAnalyzer : public ObjectRef { + public: + explicit AccessAnalyzer(const Array& tensors); + // read/write access analysis + bool NeedsMultiLevelTiling(const te::Operation& op) const; + bool IsInjective(const te::Operation& op) const; + bool IsStrictInlineable(const te::Operation& op) const; + bool IsOutput(const te::Operation& op) const; + + // Get all producers of an op + void GetProducers(const State& state, const te::Operation& op, + std::unordered_set* producers) const; + + // Get all consumers of an op. This func deals with inlined op correctly. 
+ void GetConsumers(const State& state, const te::Operation& op, + std::unordered_set* consumers) const; + + // Check whether two ops are elementwise matched + // (e.g. conv2d and relu are elementwise matched) + bool ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const; + + /*! \Note The current implementation follows these (rough) definitions. + * + * Definition of data-reuse : Exists axis in (op->axis union op->reduce_axis) + * and acc in read accesses, such that axis not in acc. + * (e.g. A[i][j] = B[i] has data reuse, while A[i][j] = B[i][j] does not) + * Definition of NeedsMultiLevelTiling: Exists two acc, both of them make this op have data reuse. + * Definition of injective : For all index expressions, they are single axis variable + * plus an optional const shift. + * (e.g. A[i][j] = B[i][j], A[i][j] = B[i+1][j] are injective, while A[i][j] = B[i*j] is not) + * Definition of strict-inlineable : All read accesses are elementwise, and no branch in the body + * (e.g. A[i][j] = B[i][j] + C[i][j] is strict-inlineable, + * while A[i][j] = tvm_if_then_else(B[i][j] > 0, C[i][j], 0) is not + */ + TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); +}; + +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +// Update StageToAxes Map during replay +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); + + +/*! 
\brief Computation declaration graph */ +class ComputeDAGNode : public Object { + public: + Array tensors; // Input and output tensors + Array ops; // All related operations in topo order + double flop_ct; // Number of float operations + AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer + ObjectRef init_state; // The initial state + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); + } + + static constexpr const char* _type_key = "ansor.ComputeDAG"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); +}; + +enum LayoutRewriteLevel { + kNoRewrite = 0, // No layout rewrite + kPlaceholderRewrite = 1, // Only rewrite layout of placeholder in the compute dag + kComputeRewrite = 2, // Only rewrite compute body for new layout in the compute dag + kBothRewrite = 3, // Rewrite both placeholder and compute body in the compute dag +}; + +/*! + * \brief Managed reference to ComputeDAGNode. + * \sa ComputeDAGNode + */ +class ComputeDAG: public ObjectRef { + public: + explicit ComputeDAG(Array tensors); + explicit ComputeDAG(const std::string& workload_key); + + // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. 
+ // The return values can be used as arguments to tvm.build or tvm.lower + std::pair > ApplySteps( + const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; + + // Print transform steps as equivalent python schedule API + std::string PrintStepsAsPython(const std::vector& steps) const; + + // Replay the transform steps and call ir_pass::InferBound to fill correct bound information + State ReplayAndInferBound(const std::vector& transform_steps) const; + + // Fill the correct bound information for a given state by calling ir_pass::InferBound + State InferBound(const State& state) const; + + // Fill the correct bound information for a list of given states. + // Return the new states inplace + void InferBound(std::vector* states) const; + + // Replay the transform steps and get the new DAG + void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; + + // Get the init state + State GetInitState() const; + + static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders"; + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); + + private: + // Internal common parts for replaying steps + std::pair > ReplaySteps( + const std::vector& transform_steps, std::vector* stages, + StageToAxesMap* stage_to_axes) const; + + // Internal common parts for inferring bound + void InferBoundCommon(StateNode* pstate) const; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_COMPUTE_DAG_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc new file mode 100644 index 000000000000..787e4256a181 --- /dev/null +++ b/src/ansor/loop_state.cc @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see ansor/loop_state.h for more explanation. + */ + +#include "loop_state.h" +#include +#include +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +// Maker for other classes +Iterator::Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters, + std::string attr) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + if (ori_iters != nullptr) { + node->ori_iters = *ori_iters; + } + node->attr = std::move(attr); + data_ = std::move(node); +} + +Stage::Stage(te::Operation op) { + auto node = make_object(); + if (op->IsInstance()) { + node->op_type = kCompute; + auto* pop = op.as(); + + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance()) { + 
node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; + node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, + const std::vector& iters, ComputeAtType compute_at, + StageAttributes attrs) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +State::State(const Array& ops) { + auto node = make_object(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->attach_map = AttachMap(make_object()); + node->complete = true; + node->aux_info = ObjectRef(); + data_ = std::move(node); +} + +State::State(const std::vector& stages, + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { + auto node = make_object(); + node->stages = stages; + node->transform_steps = transform_steps; + node->attach_map = AttachMap(make_object()); + node->complete = complete; + node->aux_info = std::move(aux_info); + data_ = std::move(node); +} + +// Schedule primitives api +void State::reorder(int stage_id, const std::vector& order) { + const Stage& stage = operator->()->stages[stage_id]; + + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + "should be specified"; + std::vector after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + 
DoReorderStep(step); +} + +std::vector State::split(int stage_id, const Iterator& it, + const std::vector& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? it->range->extent : PrimExpr(), + lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const std::vector& iters) { + const Stage& stage = operator->()->stages[stage_id]; + std::vector indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +// Steps' implementations +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + std::vector iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = Stage( + stage->op, stage->op_type, std::move(iters), stage->compute_at, + stage->attrs); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +std::vector State::DoSplitStepCommon( + int stage_id, int iter_id, const std::vector& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + size_t old_iter_size = stage->iters.size(); + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); + } + + std::vector outs; + for (size_t i = 0; i < lengths.size(); ++i) { + PrimExpr l; + std::string name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." 
+ std::to_string(i); + } + Iterator res; + if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { + res = Iterator(name, Range::make_by_min_extent(tosplit_min, l), + it->iter_type, kNone); + tosplit_min = 0; + tosplit_extent = indexdiv(tosplit_extent + l - 1, l); + } else { + res = Iterator(name, Range(), it->iter_type, kNone); + tosplit_min = tosplit_extent = PrimExpr(); + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min.defined() && tosplit_extent.defined()) { + range = Range::make_by_min_extent(tosplit_min, tosplit_extent); + } + if (inner_to_outer) { + outs.push_back( + Iterator(it->name + ".0", range, it->iter_type, kNone)); + std::reverse(outs.begin(), outs.end()); + } else { + outs.push_back( + Iterator(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); + } + + std::vector new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, + stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages[stage_id] = Stage( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->attrs); + + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + for (size_t i = iter_id; i < old_iter_size; ++i) { + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i + lengths.size()); + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return outs; +} + +std::vector State::DoSplitStep(const SplitStep& step) { + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, + step->inner_to_outer); +} + +Iterator State::DoFuseStep(const FuseStep& step) { + int stage_id = step->stage_id; + const Stage& stage = operator->()->stages[stage_id]; + int old_iter_size = 
static_cast(stage->iters.size()); + + std::string new_name; + PrimExpr new_extent = 1; + IteratorType new_iter_type = kSpecial; + + std::vector ori_iters; + for (size_t i = 0; i < step->fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); + } + + if (i != step->fused_ids.size() - 1) { + const auto& iter_to_attached_stage = + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair( + stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Because you want to fuse iterators " + "that have been attached by some stages"; + } + } + + const Iterator& it = stage->iters[step->fused_ids[i]]; + ori_iters.push_back(it); + new_name += it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_type = it->iter_type; + } else { + if (new_iter_type != it->iter_type) { + new_iter_type = kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::make_by_min_extent(0, new_extent); + } + Iterator new_it = + Iterator(new_name, range, new_iter_type, kNone, &ori_iters); + std::vector new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + step->fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), + stage->iters.begin() + step->fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages[stage_id] = Stage( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->attrs); + + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); + for (int i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + 
continue; + } else if (i > end_id) { // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return new_it; +} + +void State::DoStep(const Step& step, const ComputeDAG& dag) { + if (auto ps = step.as()) { + DoReorderStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFuseStep(GetRef(ps)); + } else { + LOG(FATAL) << "Invalid step: " << step; + } +} + +void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; + + for (const auto& step : steps) { + if (complete_rate >= 0 && ct++ > steps.size() * complete_rate) { + break; + } + DoStep(step, dag); + } +} + +void PrintStage(std::ostream* os, int stage_id, const StateNode* state, + size_t base_indent, bool delete_trivial_loop) { + const Stage& stage = state->stages[stage_id]; + + if (stage->attrs.auto_unroll_max_step != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->name + << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; + } + if (stage->attrs.storage_offset != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->name + << " storage_offset: " << stage->attrs.storage_offset << "\n"; + } + + size_t indent = 0; + for (size_t i = 0; i < stage->iters.size(); ++i) { + const Iterator& iter = stage->iters[i]; + + if (!(delete_trivial_loop && iter->range.defined() && + is_one(iter->range->extent))) { + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + 
switch (iter->annotation) { + case kNone: + *os << "for "; + break; + case kUnroll: + *os << "unroll "; + break; + case kParallel: + *os << "parallel "; + break; + case kVectorize: + *os << "vectorize "; + break; + case kVThread: + *os << "vthread "; + break; + case kBlockX: + *os << "gpu.blockIdx.x "; + break; + case kBlockY: + *os << "gpu.blockIdx.y "; + break; + case kThreadX: + *os << "gpu.threadIdx.x "; + break; + case kThreadY: + *os << "gpu.threadIdx.y "; + break; + case kTensorized: + *os << "tensorize "; + break; + default: + LOG(FATAL) << "Invalid Annotation " << iter->annotation; break; + } + if (iter->range.defined()) { + *os << iter->name << " (" << iter->range->min << "," + << iter->range->extent << ")"; + } else { + *os << iter->name << " (None)"; + } + if (!iter->attr.empty()) { + *os << " " << iter->attr; + } + *os << "\n"; + + indent += 2; + } + + if (state != nullptr) { + AttachMap::IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + PrintStage(os, attach_stage_id, state, base_indent + indent, + delete_trivial_loop); + } + } + } + } + + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + *os << stage->op->name << " = ...\n"; +} + +void PrintState(std::ostream* os, const StateNode* node, + bool delete_trivial_loop) { + // Gather placeholders + std::vector placeholders; + for (const auto& stage : node->stages) { + if (stage->op_type == kPlaceholder) { + placeholders.push_back(stage->op->name); + } + } + + *os << "Placeholder: "; + for (size_t i = 0; i < placeholders.size(); ++i) { + *os << placeholders[i]; + if (i != placeholders.size() - 1) { + *os << ", "; + } + } + *os << "\n"; + + // Print all stages + for (size_t i = 0; i < node->stages.size(); ++i) { + const Stage& stage = node->stages[i]; + if (stage->op_type == kPlaceholder) { + continue; + } else if 
(stage->op_type == kCompute) { + if (stage->compute_at == kRoot) { + PrintStage(os, i, node, 0, delete_trivial_loop); + } + } else { + LOG(FATAL) << "Invalid op type"; + } + } +} + +std::string State::ToStr(bool delete_trivial_loop) const { + std::ostringstream os; + PrintState(&os, operator->(), delete_trivial_loop); + return os.str(); +} + +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = + std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = 
pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMap(make_object()); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); +}); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { + return Array(stage->iters); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state) { + return Array(state->stages); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { + return static_cast(state->transform_steps.size()); +}); + +TVM_REGISTER_GLOBAL("ansor.StateReorder") +.set_body_typed([](State state, int stage_id, const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; +}); + +TVM_REGISTER_GLOBAL("ansor.StateSplit") +.set_body_typed([](State state, int stage_id, const Iterator& it, + 
const Array& lengths, bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; +}); + +TVM_REGISTER_GLOBAL("ansor.StateFuse") +.set_body_typed([](State state, int stage_id, + const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; +}); + +TVM_REGISTER_GLOBAL("ansor.StateEqual") +.set_body_typed([](State state1, State state2) { + return std::equal_to()(state1, state2); +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h new file mode 100644 index 000000000000..2d6c85db0247 --- /dev/null +++ b/src/ansor/loop_state.h @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.h + * \brief The definition of the "state" in search. A state consists a current loop structure + * and the transform history to reach its current loop structure. 
+ * To enable flexible manipulation of the loop structure, we implemented a lightweight + * loop structure IR (Intermediate Representation) specifically for search. + * + * Basically this is a simplified TVM IR with schedule primitives. + * We don't use the existing TVM IR because + * 1. We want fast incremental change to the loop structures + * 2. We want serializable transformation history for replay, backtracking, and mutation. + * 3. We may create some macro schedule primitives + * + * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives. + * Because we share a lot common objects during search, the transformation is + * implemented in copy on write style. All objects are immutable, which is + * similar to TVM IR. + */ + +#ifndef TVM_ANSOR_LOOP_STATE_H_ +#define TVM_ANSOR_LOOP_STATE_H_ + +#include +#include +#include +#include +#include +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +/*! \brief The type of a stage */ +enum StageType { + kPlaceholder, // A placeholder stage + kCompute // A compute stage +}; + +/*! \brief The type of compute location */ +enum ComputeAtType { + kRoot, // compute at root + kInlined, // inlined + kIter, // compute at some iterator +}; + +/*! \brief The type of an iterator */ +enum IteratorType { + kSpace, // spatial iterator + kReduce, // reduction iterator + kMixed, // fused spatial and reduction iterator + kSpecial // special iterator (e.g. virtual root iterator) +}; + +/*! \brief The type of an iterator's annotation */ +enum IteratorAnnotation { + kNone, kUnroll, kVectorize, kParallel, + kVThread, kBlockX, kThreadX, kBlockY, kThreadY, + kTensorized +}; + +// forward declaration +class Iterator; + +/*! 
+ * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + std::string name; + Range range; + IteratorType iter_type; + IteratorAnnotation annotation; + std::vector ori_iters; // The original iterators before fusion + std::string attr; // Todo(jcf94): Document this + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + v->Visit("attr", &attr); + } + + static constexpr const char *_type_key = "ansor.Iterator"; + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); +}; + +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr, + std::string attr = ""); + + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); +}; + +/*! \brief Stage-level attributes */ +struct StageAttributes { + int auto_unroll_max_step; // The maximum steps for the pragma `auto_unroll_max_step` + int storage_offset; // The storage offset for the schedule primitive `storage_align` +}; + +/*! + * \brief A stage in the compute declaration + * Similar to te::Stage in `include/schedule.h` + */ +class StageNode : public Object { + public: + te::Operation op; // The operator of this stage + StageType op_type; // The type of this stage + std::vector iters; // The iterators in this stage + ComputeAtType compute_at; // The compute location of this stage + StageAttributes attrs; // Other stage-level attributes + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + } + + static constexpr const char *_type_key = "ansor.Stage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); +}; + +/*! + * \brief Managed reference to StageNode. 
+ * \sa StageNode + */ +class Stage : public ObjectRef { + public: + explicit Stage(te::Operation op); + Stage(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, StageAttributes attrs); + Stage(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs); + + TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); +}; + +/*! \brief stores the compute_at relation between stages + * This stores a bi-directional mapping from stages and iter: + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + * + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations */ +class AttachMapNode: public Object { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + std::unordered_map stage_to_attach_iter; + std::unordered_map> iter_to_attached_stages; + + static constexpr const char* _type_key = "ansor.AttachMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); +}; + +/*! + * \brief Managed reference to AttachMapNode. + * \sa AttachMapNode + */ +class AttachMap : public ObjectRef { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + void DeleteStage(int stage_id); + void ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters); + AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); + + private: + static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); +}; + +/*! 
\brief The base class for a transformation step */ +class StepNode: public Object { + public: + int stage_id; + + // Print step as equivalent python schedule API + virtual std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const = 0; + + static constexpr const char* _type_key = "ansor.Step"; + TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); + +// Step forward decelerations +class ReorderStep; class SplitStep; class FuseStep; + +/*! \brief A state in the search process. + * It consists of the current loop structure and the history steps to reach this state. */ +class StateNode: public Object { + public: + std::vector stages; // Current stages and loop structures + std::vector transform_steps; // History transformation steps + bool complete; // Indicate whether this state has unfilled tile sizes + AttachMap attach_map; // stores the compute_at relation between stages + ObjectRef aux_info; // Used to store any auxiliary info about this state + ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. + // The default value is an empty NodeRef + // (means no modification to the DAG) + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("complete", &complete); + v->Visit("aux_info", &aux_info); + v->Visit("task_dag", &task_dag); + } + + static constexpr const char* _type_key = "ansor.State"; + TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); +}; + +/*! + * \brief Managed reference to StateNode. 
+ * \sa StateNode + */ +class State : public ObjectRef { + public: + explicit State(const Array& ops); + State(const std::vector& stages, + const std::vector& transform_steps, bool complete, + ObjectRef aux_info); + + // Schedule primitives + void reorder(int stage_id, const std::vector& order); + std::vector split(int stage_id, const Iterator& it, + const std::vector& lengths, + bool inner_to_outer = true); + Iterator fuse(int stage_id, const std::vector& iters); + + /* Do transform steps + * Note: The following functions only change loop state but do not change transform_history. + * We separate these functions out, + * so you can call them for replay easily given history steps */ + void DoReorderStep(const ReorderStep& step); + std::vector DoSplitStep(const SplitStep& step); + Iterator DoFuseStep(const FuseStep& step); + + // General do step functions with a runtime dynamic dispatcher + void DoStep(const Step& step, const ComputeDAG& dag); + void DoSteps(const std::vector& step, const ComputeDAG& dag); + + // Print the state to a string + std::string ToStr(bool delete_trivial_loop = true) const; + + TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); + + private: + // Common function for DoSplitStep and DoFollowSplitStep + std::vector DoSplitStepCommon(int stage_id, int iter_id, + const std::vector& lengths, + bool inner_to_outer); +}; + +/*! 
\brief Clean the name of an iterator to make it valid in python code */ +inline std::string CleanName(const std::string& str) { + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + +} // namespace ansor +} // namespace tvm + + +// Hash and equal function for State +namespace std { + +template <> +struct hash<::tvm::ansor::State> { + std::size_t operator()(const ::tvm::ansor::State& state) const { + return std::hash()(state.ToStr()); + } +}; + +template <> +struct equal_to<::tvm::ansor::State> { + bool operator() (const ::tvm::ansor::State& lhs, + const ::tvm::ansor::State& rhs) const { + return lhs.ToStr() == rhs.ToStr(); + } +}; + +} // namespace std + +#endif // TVM_ANSOR_LOOP_STATE_H_ diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc new file mode 100644 index 000000000000..c50191813b2e --- /dev/null +++ b/src/ansor/measure.cc @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ansor/measure.cc + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs + */ + +#include "measure.h" +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(MeasureInputNode); +TVM_REGISTER_NODE_TYPE(BuildResultNode); +TVM_REGISTER_NODE_TYPE(MeasureResultNode); +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_OBJECT_TYPE(BuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); + +const char* ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", +}; + +// Measure input and result +MeasureInput::MeasureInput(SearchTask task, State state) { + auto node = make_object(); + node->task = std::move(task); + node->state = std::move(state); + data_ = std::move(node); +} + +MeasureInput MeasureInputNode::copy() const { + auto node = make_object(); + node->task = task; + node->state = state; + return MeasureInput(node); +} + +BuildResult::BuildResult(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { + auto node = make_object(); + node->filename = std::move(filename); + node->args = std::move(args); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->time_cost = time_cost; + data_ = std::move(node); +} + +MeasureResult::MeasureResult(Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + auto node = make_object(); + node->costs = std::move(costs); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->all_cost = all_cost; + node->timestamp = timestamp; + data_ = std::move(node); +} + +MeasureResult MeasureResultNode::copy() const { + auto node = make_object(); + node->costs = 
costs; + node->error_no = error_no; + node->error_msg = error_msg; + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +// LocalBuilder +LocalBuilder::LocalBuilder(int timeout, int n_parallel, + const std::string& build_func) { + auto node = make_object(); + node->timeout = timeout; + node->n_parallel = n_parallel; + node->build_func = build_func; + data_ = std::move(node); +} + +Array LocalBuilderNode::Build(const Array& inputs, + int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { + Array results = + (*f)(inputs, timeout, n_parallel, build_func, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_builder.build is not registered"; + } + return Array(); +} + +// Local Runner +LocalRunner::LocalRunner(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + ObjectPtr node = make_object(); + node->timeout = timeout; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + data_ = std::move(node); +} + +Array LocalRunnerNode::Run( + const Array& inputs, const Array& build_results, + int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { + Array results = + (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, + cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_runner.run is not registered"; + } + return Array(); +} + +// Program Measurer +ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error) { + auto node = make_object(); + node->builder = std::move(builder); + node->runner = std::move(runner); + node->callbacks = std::move(callbacks); + node->verbose = verbose; + node->max_continous_error = max_continous_error < 0 ? 
+ ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + data_ = std::move(node); +} + +void ProgramMeasurerNode::Reset() { + ct = error_ct = 0; + best_flops.clear(); + best_ct.clear(); + best_state.clear(); +} + +void ProgramMeasurerNode::Measure(const SearchTask& task, + const SearchPolicy& policy, + const std::vector& inputs, + std::vector* results, + int batch_size) { + results->clear(); + results->reserve(inputs.size()); + + if (batch_size == -1) { + // set default batch size + batch_size = builder->n_parallel * 2; + } + + StdCout(verbose) << "Get " << inputs.size() + << " programs for measure. (This may take a while)" + << std::endl; + + for (size_t i = 0; i < inputs.size(); i += batch_size) { + std::vector input_batch( + inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector result_batch; + + // build and run + SilentMeasure(task, input_batch, &result_batch); + + // update current best state according to the new measure result + for (size_t j = 0; j < input_batch.size(); ++j) { + double flops; + if (result_batch[j]->error_no == 0) { + flops = + task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + error_ct = 0; + } else { + flops = 0.0; + error_ct++; + } + + const std::string& workload_key = input_batch[j]->task->workload_key; + if (flops > best_flops[workload_key]) { + best_flops[workload_key] = flops; + best_state[workload_key] = input_batch[j]->state; + best_ct[workload_key] = ct; + } + + ct++; + if (verbose >= 1) { + std::cout << std::fixed << std::setprecision(2); + std::cout << "===============================================\n"; + std::cout << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 + << "\tresults: " << result_batch[j] << "\n"; + std::cout << "===============================================\n"; + std::cout << input_batch[j]->state << "\n"; + } + } + + // Call callback functions + for (const auto& callback : callbacks) { + 
callback->callback(policy, input_batch, result_batch); + } + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } + + if (error_ct > max_continous_error) { + LOG(FATAL) << "Too many errors happened during tuning"; + } + } +} + +void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, + const std::vector& inputs, + std::vector* results) { + // Close the thread pool to avoid the conflits with python environment + ThreadPool::Global().Abort(); + + results->clear(); + results->reserve(inputs.size()); + Array input_batch(inputs.begin(), inputs.end()); + + // Call builder and runner + Array build_res_batch = builder->Build(input_batch, verbose); + Array result_batch = + runner->Run(input_batch, build_res_batch, verbose); + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } +} + +// Printing functions +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, 
ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; +}); + +TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { + return MeasureInput(task, state); +}); + +TVM_REGISTER_GLOBAL("ansor.BuildResult") +.set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResult(filename, args, error_no, error_msg, time_cost); +}); + +TVM_REGISTER_GLOBAL("ansor.MeasureResult") +.set_body_typed([](Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp) { + return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); +}); + +TVM_REGISTER_GLOBAL("ansor.BuilderBuild") +.set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); +}); + +TVM_REGISTER_GLOBAL("ansor.RunnerRun") +.set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); +}); + +TVM_REGISTER_GLOBAL("ansor.LocalBuilder") +.set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { + return LocalBuilder(timeout, n_parallel, build_func); +}); + +TVM_REGISTER_GLOBAL("ansor.LocalRunner") +.set_body_typed([](int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); +}); + +TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") +.set_body_typed([](Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error = -1) { + return ProgramMeasurer(builder, runner, callbacks, verbose, + max_continous_error); +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/measure.h b/src/ansor/measure.h new file mode 100644 index 000000000000..630365512eb6 --- 
/dev/null +++ b/src/ansor/measure.h @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/measure.h + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs + */ + +#ifndef TVM_ANSOR_MEASURE_H_ +#define TVM_ANSOR_MEASURE_H_ + +#include +#include +#include +#include +#include "search_task.h" +#include "loop_state.h" + +namespace tvm { +namespace ansor { + +class SearchPolicy; +class MeasureInput; class BuildResult; class MeasureResult; +class Builder; class Runner; class MeasureCallback; class ProgramMeasurer; + +/* \brief The error code of one measurement */ +enum MeasureErrorNO { + kNoError = 0, // No error + kInstantiationError = 1, // Errors happen when apply transform steps from init state + kCompileHostError = 2, // Errors happen when compiling code on host (when build module) + kCompileDeviceError = 3, // Errors happen when compiling code on device (when load module) + kRuntimeDeviceError = 4, // Errors happen when run program on device + kWrongAnswerError = 5, // Answer is wrong when compared to a reference output + kBuildTimeoutError = 6, // Timeout during compilation + kRunTimeoutError = 7, // Timeout during run + 
kUnknonwError = 8, // Unknown error +}; +extern const char *ErrorNoToStr[]; + +// Inputs and results of one measurement + +/*! \brief Store the input of a measurement */ +class MeasureInputNode: public Object { + public: + SearchTask task; // The search task + State state; // The program state to be measured + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("task", &task); + v->Visit("state", &state); + } + + MeasureInput copy() const; // Do deep copy + + static constexpr const char* _type_key = "ansor.MeasureInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); +}; + +/*! + * \brief Managed reference to MeasureInputNode. + * \sa MeasureInputNode + */ +class MeasureInput : public ObjectRef { + public: + MeasureInput(SearchTask task, State state); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode); +}; + +/*! \brief Store the input of a build */ +class BuildResultNode: public Object { + public: + std::string filename; // The filename of built binary file + Array args; // The arguments + int error_no; // The error code (see MeasureErrorNO). + // 0 means no error. + std::string error_msg; // The error message if there is any error + double time_cost; // The time cost of build + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("filename", &filename); + v->Visit("args", &args); + v->Visit("error_no", &error_no); + v->Visit("error_msg", &error_msg); + v->Visit("time_cost", &time_cost); + } + + static constexpr const char* _type_key = "ansor.BuildResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); +}; + +/*! + * \brief Managed reference to BuildResultNode. + * \sa BuildResultNode + */ +class BuildResult : public ObjectRef { + public: + BuildResult(std::string filename, Array args, + int error_no, std::string error_msg, double time_cost); + TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); +}; + +/*! 
\brief Store the results of a measurement */ +class MeasureResultNode: public Object { + public: + Array costs; // The time costs of execution + int error_no; // The error code (see MeasureErrorNO). + // 0 means no error. + std::string error_msg; // The error message if there is any error + double all_cost; // The time cost of build and run + double timestamp; // The time stamps of this measurement + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("costs", &costs); + v->Visit("error_no", &error_no); + v->Visit("error_msg", &error_msg); + v->Visit("all_cost", &all_cost); + v->Visit("timestamp", ×tamp); + } + + MeasureResult copy() const; // Do deep copy + + static constexpr const char* _type_key = "ansor.MeasureResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); +}; + +/*! + * \brief Managed reference to MeasureResultNode. + * \sa MeasureResultNode + */ +class MeasureResult : public ObjectRef { + public: + MeasureResult(Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode); +}; + +/*! \brief Bass class of measurement callbacks */ +class MeasureCallbackNode: public Object { + public: + /*! \biref Callback function that will be called on measurement input/result pairs + * after measurement */ + virtual void callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) = 0; + static constexpr const char *_type_key = "ansor.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode); + +// Base class for builder and runner +/*! \brief Builder that builds the programs */ +class BuilderNode: public Object { + public: + int n_parallel; // The number of tasks to run in parallel + int timeout; // Timeout of a build + + /*! 
\biref Build programs and return results */ + virtual Array Build(const Array& inputs, int verbose) = 0; + + static constexpr const char* _type_key = "ansor.Builder"; + TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode); + +/*! \brief Runner that runs the built programs and measure the time cost */ +class RunnerNode: public Object { + public: + int timeout; // Timeout of a run + + /*! \biref Run measurement and return results */ + virtual Array Run(const Array& inputs, + const Array& build_results, + int verbose) = 0; + + static constexpr const char* _type_key = "ansor.Runner"; + TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode); + + +// Implementation of various builders and runners +/*! \brief LocalBuilder use local CPU cores to build programs in parallel */ +class LocalBuilderNode: public BuilderNode { + public: + std::string build_func; // Build function + + Array Build(const Array& inputs, int verbose) final; + + static constexpr const char* _type_key = "ansor.LocalBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode); +}; + +/*! + * \brief Managed reference to LocalBuilderNode. + * \sa LocalBuilderNode + */ +class LocalBuilder: public Builder { + public: + LocalBuilder(int timeout, int n_parallel, const std::string& build_func); + + TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); +}; + +/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +class LocalRunnerNode: public RunnerNode { + public: + int number; + int repeat; + int min_repeat_ms; + double cooldown_interval; + + /*! \biref Run measurement and return results */ + Array Run(const Array& inputs, + const Array& build_results, + int verbose) final; + + static constexpr const char* _type_key = "ansor.LocalRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode); +}; + +/*! 
+ * \brief Managed reference to LocalRunnerNode. + * \sa LocalRunnerNode + */ +class LocalRunner: public Runner { + public: + LocalRunner(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner, + LocalRunnerNode); +}; + +/*! + * \brief Measurer that measures the time costs of tvm programs + * This class combines Builder and Runner, and provides a simpler API */ +class ProgramMeasurerNode: public Object { + public: + static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; + + int ct; + int error_ct; // continuous error counter + std::unordered_map best_flops; + std::unordered_map best_state; + std::unordered_map best_ct; + + Builder builder; + Runner runner; + Array callbacks; + int verbose; + int max_continous_error; + + /*! \brief Reset book keeping variables */ + void Reset(); + + /*! \biref Do measurement */ + void Measure(const SearchTask& task, + const SearchPolicy& policy, + const std::vector& inputs, + std::vector* results, + int batch_size = -1); + + /*! \biref Do measurement silently */ + void SilentMeasure(const SearchTask& task, + const std::vector& inputs, + std::vector* results); + + static constexpr const char* _type_key = "ansor.ProgramMeasurer"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); +}; + +/*! + * \brief Managed reference to ProgramMeasurerNode. 
+ * \sa ProgramMeasurerNode + */ +class ProgramMeasurer : public ObjectRef { + public: + ProgramMeasurer(Builder builder, Runner runner, + Array callbacks, + int verbose, int max_continous_error = -1); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_MEASURE_H_ diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc new file mode 100644 index 000000000000..ba861f333c78 --- /dev/null +++ b/src/ansor/search_policy/empty_policy.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "empty_policy.h" + +#include + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); + +State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) { + cur_task = task; + + // Run pre_search_callbacks before the search process + // This Interface is usually used to set some init status + RunCallbacks(pre_search_callbacks); + + if (n_trials <= 1) { + const auto& res = SearchOneRound(); + CHECK_GT(res.size(), 0); + return res[0]; + } else { + std::vector inputs; + std::vector results; + + measurer->Reset(); + int ct = 0; + // In each round, we call SearchOneRound to get several candidate states, + // then use ProgramMeasurer to test their performance + while (ct < n_trials) { + const auto& res = SearchOneRound(); + ct += res.size(); + inputs.clear(); + for (const auto& state : res) { + inputs.emplace_back(cur_task, state); + } + measurer->Measure(cur_task, GetRef(this), inputs, &results); + } + + // Return a state with best measured performance + return measurer->best_state[cur_task->workload_key]; + } +} + +std::pair, Array > EmptyPolicyNode::ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { + // The whole process is almost the same as Search, while this function is designed to be + // called and managed by another global task scheduler + + std::vector inputs; + std::vector results; + + const auto& res = SearchOneRound(); + for (const auto& state : res) { + inputs.emplace_back(cur_task, state); + } + measurer->Measure(cur_task, GetRef(this), inputs, &results); + + // Return a pair of MeasureInput Array and MeasureResult Array + Array inputs_arr(std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); + Array results_arr(std::make_move_iterator(results.begin()), + std::make_move_iterator(results.end())); + return 
std::make_pair(std::move(inputs_arr), std::move(results_arr)); +} + +std::vector EmptyPolicyNode::SearchOneRound() { + std::vector res; + res.push_back(cur_task->compute_dag.GetInitState()); + // As an example policy, EmptyPolicy always return a init state + return res; +} + +TVM_REGISTER_GLOBAL("ansor.EmptyPolicy") +.set_body_typed([]() { return EmptyPolicy(make_object()); }); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h new file mode 100644 index 000000000000..5c2f52608fe0 --- /dev/null +++ b/src/ansor/search_policy/empty_policy.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/empty_policy.h + * \brief This is an basic example of search policy + */ + +#ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ + +#include +#include + +#include "search_policy.h" + +namespace tvm { +namespace ansor { + +/*! + * \file ansor/search_policy/empty_policy.h + * \brief This is an basic example for search policy. The EmptyPolicy will + * always generates the init state of a ComputeDAG. 
+ */ +class EmptyPolicyNode : public SearchPolicyNode { + public: + /*! \brief Search and make n_trails measurements. + * \returns the best state + */ + State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) final; + + /*! \brief Continue search for one round. This is used by JointTuner + * \returns the measurement pairs + */ + std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; + + static constexpr const char *_type_key = "ansor.EmptyPolicy"; + TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); + + private: + /*! + * \brief Usually we need a sub function to generate several candidate states in each + * search round. + * \returns Several generated states + */ + std::vector SearchOneRound(); +}; + +/*! + * \brief Managed reference to EmptyPolicyNode. + * \sa EmptyPolicyNode + */ +class EmptyPolicy : public SearchPolicy { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ \ No newline at end of file diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc new file mode 100644 index 000000000000..e7a12702ba70 --- /dev/null +++ b/src/ansor/search_policy/search_policy.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.cc + * \brief The base class for search policy + */ + +#include "search_policy.h" +#include +#include "../serialization.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); + +void SearchPolicyNode::RunCallbacks(const Array& callbacks) { + if (callbacks.defined() && callbacks.size()) { + for (const auto& callback : callbacks) { + callback->callback(this); + } + } +} + +// Search Policy +TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") +.set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, + int verbose, ProgramMeasurer measurer) { + Array inputs; + Array results; + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); + return Array{inputs, results}; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") +.set_body_typed([](SearchPolicy policy, Array callbacks) { + policy->RunCallbacks(callbacks); +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") +.set_body_typed([](SearchPolicy policy, SearchTask task) { + policy->cur_task = task; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") +.set_body_typed([](SearchPolicy policy, int verbose) { + policy->verbose = verbose; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h new file mode 100644 index 000000000000..eb4703be1914 --- /dev/null +++ 
b/src/ansor/search_policy/search_policy.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.h + * \brief The base class for search policy + */ + +#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include "../search_task.h" +#include +#include +#include +#include +#include +#include "../measure.h" + +namespace tvm { +namespace ansor { + +class SearchPolicyNode; + +/*! \brief Callback function to be called before or after the search process */ +class SearchCallbackNode : public Object { + public: + virtual void callback(SearchPolicyNode* policy) = 0; + + static constexpr const char *_type_key = "ansor.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); + +/*! 
\brief The base class for search policy */ +class SearchPolicyNode : public Object { + public: + SearchTask cur_task; // The current task + int verbose; // Verbose level (0 means silent) + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + // Search for a task + virtual State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) = 0; + + // Continue search one round for a task. + // This is used in the task scheduler for searching for multiple tasks together. + virtual std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; + + // Run a list of callback functions + void RunCallbacks(const Array& callbacks); + + static constexpr const char *_type_key = "ansor.SearchPolicy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); + + protected: + // The set of the already measured states. + // We store the string format for redundancy check + std::unordered_set measured_states_set_; + // The array of already measured states. + std::vector measured_states_vector_; + // The throughputs of already measured states + std::vector measured_states_throughputs_; +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc new file mode 100644 index 000000000000..6be4773fe780 --- /dev/null +++ b/src/ansor/search_task.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_task.cc + * \brief Meta information and hardware parameters for a search task + */ + +#include "search_task.h" +#include +#include +#include +#include +#include + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(HardwareParamsNode); +TVM_REGISTER_NODE_TYPE(SearchTaskNode); + +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + auto node = make_object(); + node->num_cores = num_cores; + node->vector_unit_bytes = vector_unit_bytes; + node->cache_line_bytes = cache_line_bytes; + node->max_unroll_vec = max_unroll_vec; + node->max_innermost_split_factor = max_innermost_split_factor; + data_ = std::move(node); +} + +HardwareParams HardwareParamsNode::GetDefaultHardwareParams( + const Target& target, const Target& target_host) { + if (target->target_name == "llvm") { + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), + 32, 64, 16, 64); + } else { + LOG(FATAL) << "No default hardware parameters for target: " << target; + } + return HardwareParams(); +} + +SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + auto node = make_object(); + node->compute_dag = std::move(compute_dag); + node->workload_key = std::move(workload_key); + node->target 
= std::move(target); + node->target_host = std::move(target_host); + if (hardware_params.defined()) { + node->hardware_params = std::move(hardware_params); + } else { + node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( + node->target, node->target_host); + } + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("ansor.HardwareParams") +.set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor); +}); + +TVM_REGISTER_GLOBAL("ansor.SearchTask") +.set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTask(compute_dag, workload_key, target, target_host, + hardware_params); +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h new file mode 100644 index 000000000000..0f270d105d73 --- /dev/null +++ b/src/ansor/search_task.h @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ansor/search_task.h + * \brief Meta information and hardware parameters for a search task + */ + +#ifndef TVM_ANSOR_SEARCH_TASK_H_ +#define TVM_ANSOR_SEARCH_TASK_H_ + +#include +#include +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +class HardwareParams; + +/*! \brief Hardware related parameters */ +class HardwareParamsNode : public Object { + public: + // The number of cores + int num_cores; + // The width of vector units in bytes + int vector_unit_bytes; + // The size of cache line in bytes + int cache_line_bytes; + // The max length of an axis to be unrolled or vectorized + int max_unroll_vec; + // The max split factor for the innermost tile + int max_innermost_split_factor; + + // Limitation params for GPU + int max_shared_memory_per_block{INT32_MAX}; + int max_registers_per_block{INT32_MAX}; + int max_threads_per_block{INT32_MAX}; + int max_vthread_extent{INT32_MAX}; + int warp_size{INT32_MAX}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_cores", &num_cores); + v->Visit("vector_unit_bytes", &vector_unit_bytes); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("max_unroll_vec", &max_unroll_vec); + v->Visit("max_innermost_split_factor", &max_innermost_split_factor); + + v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); + v->Visit("max_registers_per_block", &max_registers_per_block); + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("max_vthread_extent", &max_vthread_extent); + v->Visit("warp_size", &warp_size); + } + + static HardwareParams GetDefaultHardwareParams(const Target& target, + const Target& target_host); + + static constexpr const char* _type_key = "ansor.HardwareParams"; + TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); +}; + +/*! + * \brief Managed reference to HardwareParamsNode. 
+ * \sa HardwareParamsNode + */ +class HardwareParams : public ObjectRef { + public: + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor); + + TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); +}; + +/*! \brief Meta-info for a search task */ +class SearchTaskNode : public Object { + public: + ComputeDAG compute_dag; + std::string workload_key; + Target target; + Target target_host; + HardwareParams hardware_params; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("compute_dag", &compute_dag); + v->Visit("workload_key", &workload_key); + v->Visit("target", &target); + v->Visit("target_host", &target_host); + v->Visit("hardware_params", &hardware_params); + } + + static constexpr const char* _type_key = "ansor.SearchTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); +}; + +/*! + * \brief Managed reference to SearchTaskNode. + * \sa SearchTaskNode + */ +class SearchTask : public ObjectRef { + public: + SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params); + + TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_TASK_H_ diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc new file mode 100644 index 000000000000..939fca83f1fb --- /dev/null +++ b/src/ansor/serialization.cc @@ -0,0 +1,503 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/serialization.cc + * \brief Json serialization format for dumping and loading tuning records + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "serialization.h" +#include "loop_state.h" +#include "transform_step.h" +#include "utils.h" + +// Json serialization handler for MeasureInput, MeasureResult +// (and recursively for SearchTask, State, Step, ...) +namespace dmlc { +namespace json { + +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { + out->clear(); + for (const auto&x : data) { + auto pi = x.as<::tvm::tir::IntImmNode>(); + CHECK(pi != nullptr) << "Can only contain int values"; + out->push_back(pi->value); + } + return *out; +} + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Stage> & data) { + writer->BeginArray(false); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Stage> * data) { + bool s; + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Step> & data) { + std::vector tmp; + writer->BeginArray(false); + for (size_t i = 0; i < data.size(); ++i) { + writer->WriteArraySeperator(); + writer->BeginArray(false); + if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { + writer->WriteArrayItem(std::string("RE")); + 
writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->after_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { + writer->WriteArrayItem(std::string("SP")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + if (ps->extent.defined()) { + writer->WriteArrayItem(::tvm::ansor::GetIntImm(ps->extent)); + } else { + writer->WriteArrayItem(0); + } + writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); + writer->WriteArrayItem(static_cast(ps->inner_to_outer)); + } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { + writer->WriteArrayItem(std::string("FU")); + writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->fused_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else { + LOG(FATAL) << "Invalid step: " << data[i]; + } + writer->EndArray(); + } + writer->EndArray(); + } + + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Step> * data) { + std::vector int_list; + bool s, inner_to_outer; + std::string name, scope_name, pragma_type, ti_func_name; + int stage_id, iter_id, extent; + + reader->BeginArray(); + data->clear(); + while (reader->NextArrayItem()) { + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&name); + if (name == "RE") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::ReorderStep(stage_id, int_list)); + } else if (name == "SP") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&extent); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + s = 
reader->NextArrayItem(); CHECK(s); + reader->Read(&inner_to_outer); + data->push_back(::tvm::ansor::SplitStep( + stage_id, iter_id, extent, + std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), + inner_to_outer)); + } else if (name == "FU") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); + } else { + LOG(FATAL) << "Invalid step format"; + } + s = reader->NextArrayItem(); CHECK(!s); + } + } +}; + +template <> +struct Handler<::tvm::ansor::StateNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::StateNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.stages); + writer->WriteArrayItem(data.transform_steps); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::StateNode* data) { + reader->BeginArray(); + bool s; + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->stages); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->transform_steps); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::SearchTaskNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::SearchTaskNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.workload_key); + writer->WriteArrayItem(data.target->str()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::SearchTaskNode* data) { + std::string target_str; + bool s; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->workload_key); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&target_str); + data->target = ::tvm::Target::Create(target_str); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::MeasureInputNode> { + inline static void 
Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureInputNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(*data.task.operator->()); + writer->WriteArrayItem(*data.state.operator->()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureInputNode* data) { + bool s; + auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); + auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); + state_node->complete = true; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(task_node.get()); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(state_node.get()); + s = reader->NextArrayItem(); CHECK(!s); + + data->task = ::tvm::ansor::SearchTask(task_node); + data->state = ::tvm::ansor::State(state_node); + } +}; + +template <> +struct Handler<::tvm::ansor::MeasureResultNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureResultNode& data) { + writer->BeginArray(false); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto&x : data.costs) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + writer->WriteArrayItem(pf->value); + } + writer->EndArray(); + writer->WriteArrayItem(data.error_no); + writer->WriteArrayItem(data.all_cost); + writer->WriteArrayItem(static_cast((data.timestamp))); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureResultNode* data) { + bool s; + std::vector tmp; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&tmp); + data->costs.clear(); + for (const auto& i : tmp) { + data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); + } + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->error_no); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->all_cost); + s = reader->NextArrayItem(); 
CHECK(s); + reader->Read(&data->timestamp); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +} // namespace json +} // namespace dmlc + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(LogToFileNode); +TVM_REGISTER_OBJECT_TYPE(LogReaderNode); + +const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) + +LogToFile::LogToFile(std::string filename) { + auto node = make_object(); + node->filename = std::move(filename); + data_ = std::move(node); +} + +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results) { + dmlc::JSONWriter writer(os); + for (size_t i = 0; i < inputs.size(); ++i) { + writer.BeginObject(false); + writer.WriteObjectKeyValue("i", *inputs[i].operator->()); + writer.WriteObjectKeyValue("r", *results[i].operator->()); + writer.WriteObjectKeyValue("v", ANSOR_LOG_VERSION); + writer.EndObject(); + *os << "\n"; + } +} + +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version) { + std::istringstream ss(str); + dmlc::JSONReader reader(&ss); + std::string key; + + reader.BeginObject(); + while (reader.NextObjectItem(&key)) { + if (key == "i") { + reader.Read(inp); + } else if (key == "r") { + reader.Read(res); + } else if (key == "v") { + reader.Read(log_version); + } else { + LOG(FATAL) << "Invalid key in json log: " << key; + } + } +} + +void LogToFileNode::callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) { + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, inputs, results); +} + +LogReader::LogReader(std::string filename) { + auto node = make_object(); + node->filename = filename; + node->infile.open(filename, std::ifstream::in); + data_ = std::move(node); +} + +LogReaderNode::~LogReaderNode() { + infile.close(); +} + +bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { + std::string log_version; + + while (std::getline(infile, cur_line)) { + if 
(cur_line[0] == '#' || cur_line[0] == ' ') { + // skip comment lines begin with '#' or ' ' + continue; + } + ReadMeasureRecord(cur_line, inp, res, &log_version); + return true; + } + + return false; +} + +std::pair, Array > LogReaderNode::ReadLines( + int max_size, int skip_size) { + auto inp = make_object(); + auto res = make_object(); + Array inputs; + Array results; + + while (ReadNext(inp.get(), res.get())) { + if (skip_size > 0) { + skip_size--; + continue; + } + + inputs.push_back(inp->copy()); + results.push_back(res->copy()); + + if (max_size > 0 && static_cast(inputs.size()) >= max_size) { + break; + } + } + + return std::make_pair(inputs, results); +} + +std::pair BestMeasurePairInFile( + const std::string& filename, const std::string& workload_key, + const Target& target) { + std::pair best_pair; + double best_cost = 1e30; + + auto inp = make_object(); + auto res = make_object(); + LogReader reader = LogReader(filename); + + while (reader->ReadNext(inp.get(), res.get())) { + if (res->error_no != kNoError || inp->task->workload_key != workload_key + || inp->task->target->target_name != target->target_name) { + continue; + } + + double cost = FloatArrayMean(res->costs); + + if (cost < best_cost) { + best_cost = cost; + best_pair = std::make_pair(inp->copy(), res->copy()); + } + } + + return best_pair; +} + +TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const std::string& filename) { + return LogToFile(filename); +}); + +TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const std::string& filename) { + return LogReader(filename); +}); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") +.set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; +}); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") +.set_body_typed([](LogReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + 
return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } +}); + +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); + +TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array inputs = args[0]; + SearchTask external_task; + + if (args.size() > 1) { + external_task = args[1]; + } + + Array states; + states.reserve(inputs.size()); + + // (workload_key, target) -> (search_task) + std::unordered_map, SearchTask> task_cache; + + for (const auto& inp : inputs) { + const std::string& workload_key = inp->task->workload_key; + std::pair key(workload_key, inp->task->target->str()); + + const SearchTaskNode* ptask; + if (external_task.defined()) { + ptask = external_task.operator->(); + } else { + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inp->task->compute_dag.defined()) { // the measure input is complete + ptask = inp->task.operator->(); + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + SearchTask new_task = SearchTask( + ComputeDAG(workload_key), + workload_key, + inp->task->target, + inp->task->target_host, + inp->task->hardware_params); + task_cache.insert(std::make_pair(key, new_task)); + ptask = new_task.operator->(); + } + } else { + ptask = find_res->second.operator->(); + } + } + + State tmp_s = ptask->compute_dag.GetInitState(); + StateNode *ps = tmp_s.CopyOnWrite(); + ps->transform_steps = inp->state->transform_steps; + tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); + states.push_back(std::move(tmp_s)); + } + + *ret = states; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h new file mode 100644 
index 000000000000..82dd036991e6 --- /dev/null +++ b/src/ansor/serialization.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/serialization.h + * \brief Json serialization format for dumping and loading tuning records + */ + +#ifndef TVM_ANSOR_SERIALIZATION_H_ +#define TVM_ANSOR_SERIALIZATION_H_ + +#include +#include +#include +#include "measure.h" + +namespace tvm { +namespace ansor { + +/*! \brief Callback for logging the input and results of measurements to file */ +class LogToFileNode : public MeasureCallbackNode { + public: + std::string filename; + + /*! \brief Log measure pairs to file. This is called by the search policy */ + void callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) final; + + static constexpr const char *_type_key = "ansor.LogToFile"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to LogToFileNode. + * \sa LogToFileNode + */ +class LogToFile : public MeasureCallback { + public: + explicit LogToFile(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); +}; + +/*! 
\brief Log reader to load step logs from a target file.*/ +class LogReaderNode : public Object { + public: + std::string filename; + std::ifstream infile; + + ~LogReaderNode(); + + /*! \brief Read next line in the log file + * \return Whether the read is successful */ + bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); + + /*! \brief Read multiple lines from the log file + * \param max_size The maximum number of lines. -1 means read all lines + * \param skip_size Skip the first n lines */ + std::pair, Array > ReadLines( + int max_size = -1, int skip_size = 0); + + static constexpr const char* _type_key = "ansor.LogReader"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + + private: + std::string cur_line; +}; + +/*! + * \brief Managed reference to LogReaderNode. + * \sa LogReaderNode + */ +class LogReader : public ObjectRef { + public: + explicit LogReader(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); +}; + +/*! \brief Write measure records to an output stream */ +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results); + +/*! \brief Read one measure record from a string */ +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version); + +/*! \brief Return the best measure pair with lowest cost in a file */ +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SERIALIZATION_H_ diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc new file mode 100644 index 000000000000..1bcea3f690c9 --- /dev/null +++ b/src/ansor/transform_step.cc @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/transform_step.cc + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * + * See the note in transform_step.h on how to add a new step + */ + +#include "transform_step.h" +#include +#include +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +/********** Reorder **********/ +ReorderStep::ReorderStep(int stage_id, const std::vector& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->after_ids = after_ids; + data_ = std::move(node); +} + +void ReorderStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + CHECK_EQ(after_ids.size(), axes.size()); + + std::vector new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); +} + +std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const te::Stage& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << 
CleanName(stage->op->name) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +std::vector ApplySplitToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + std::vector outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i], &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + } else { + new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return outs; +} + +std::string PrintSplitAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + auto to_split = (*stage_to_axes)[stage][iter_id]; + const auto& func_name = CleanName(stage->op->name); + const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, + iter_id, lengths, inner_to_outer); + + 
std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) { + ss << CleanName(outs[size - i]->var->name_hint) << ", " + << CleanName(outs[size - i - 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", factor=" << lengths[i] << ")\n"; + to_split = outs[size - i]; + } + } else { + for (int i = 0; i < size; i++) { + ss << CleanName(outs[i]->var->name_hint) << ", " + << CleanName(outs[i + 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", nparts=" << lengths[i] << ")\n"; + to_split = outs[i + 1]; + } + } + + return ss.str(); +} + +SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer) { + auto node = make_object(); + node->stage_id = stage_id; + // Extent can be a unreducible expression in some special cases + if (extent->IsInstance()) { + node->extent = std::move(extent); + } + node->iter_id = iter_id; + node->lengths = lengths; + node->inner_to_outer = inner_to_outer; + data_ = std::move(node); +} + +std::vector SplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes) const { + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +std::string SplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +/********** Fuse **********/ +FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->fused_ids = fused_ids; + data_ = std::move(node); +} + +IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = 
(*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + Array to_fuse; + for (auto i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); + new_axes.push_back(fused_axis); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, + axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return fused_axis; +} + +std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const auto& stage = (*stages)[stage_id]; + std::stringstream to_fuse; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); + if (i != fused_ids.size() - 1) { + to_fuse << ", "; + } + } + + std::stringstream ss; + const auto& fused = ApplyToSchedule(stages, stage_to_axes); + + ss << CleanName(fused->var->name_hint) << " = s[" + << CleanName(stage->op->name) << "].fuse(" + << to_fuse.str() << ")\n"; + + return ss.str(); +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h new file mode 100644 index 000000000000..8eff6a4e7536 --- /dev/null +++ b/src/ansor/transform_step.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/transform_step.h + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * + * \note How to add a new transform step. + * Take fuse for example: + * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction + * function `FuseStep::FuseStep(...)` in `transform_steps.cc` + * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`. + * - In these two functions you need to lower this step with tvm's te schedule API + * 3. Implement `State::fuse` and `State::DoFuseStep`. + * - In these two functions you need to incrementally update all data structures in State with + * CopyOnWrite style + * 4. Add you step to `ComputeDAG::ReplaySteps` and make sure it works. + * 5. Add serialization support in `struct Handler >` + * in `serialization.cc` + * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) + * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test + */ + +#ifndef TVM_ANSOR_TRANSFORM_STEP_H_ +#define TVM_ANSOR_TRANSFORM_STEP_H_ + +#include +#include +#include +#include "loop_state.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +/*! \brief Reorder step that corresponds to te::Stage::reorder */ +class ReorderStepNode: public StepNode { + public: + std::vector after_ids; // The iterator ids after reorder. + // This array should specify the order of all iterators. 
+ + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); +}; + +/*! + * \brief Managed reference to ReorderStepNode. + * \sa ReorderStepNode + */ +class ReorderStep : public Step { + public: + ReorderStep(int stage_id, const std::vector& after_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); +}; + +/*! \brief Split step that corresponds to te::Stage::split with additional + * support of multiple-level of factors */ +class SplitStepNode: public StepNode { + public: + int iter_id; // The id of the iter to split + PrimExpr extent; // the extent length of the axis to split + std::vector lengths; // The split factors + bool inner_to_outer; // If true, the `lengths` denote the lengths of + // iterators from inner level to outer level + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.SplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); +}; + +/*! + * \brief Managed reference to SplitStepNode. + * \sa SplitStepNode + */ +class SplitStep : public Step { + public: + SplitStep(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer); + + TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); +}; + +/*! 
\brief Fuse step that corresponds to te::Stage::fuse */ +class FuseStepNode: public StepNode { + public: + std::vector fused_ids; // The ids of iterators to fuse + + IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; + +/*! + * \brief Managed reference to FuseStepNode. + * \sa FuseStepNode + */ +class FuseStep : public Step { + public: + FuseStep(int stage_id, const std::vector& fused_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); +}; + +} // namespace ansor +} // namespace tvm + +// Hash and equal function for Step +namespace std { + +template <> +struct hash<::tvm::ansor::Step> { + std::size_t operator()(const ::tvm::ansor::Step& step) const { + if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { + return ::dmlc::HashCombine(1, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ps->after_ids)); + } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { + size_t ret = ::dmlc::HashCombine(2, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->inner_to_outer))); + for (const auto& len : ps->lengths) { + if (len.defined()) { + auto pint = len.as<::tvm::tir::IntImmNode>(); + CHECK(pint != nullptr); + ret = ::dmlc::HashCombine(ret, pint->value); + } else { + ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number + } + } + return ret; + } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { + return ::dmlc::HashCombine(3, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ps->fused_ids)); + } else { + LOG(FATAL) << "Invalid step"; + } + return 0; + } +}; +} // namespace std + +#endif // TVM_ANSOR_TRANSFORM_STEP_H_ diff --git a/src/ansor/utils.cc 
b/src/ansor/utils.cc new file mode 100644 index 000000000000..27aac7e8b315 --- /dev/null +++ b/src/ansor/utils.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/utils.cc + * \brief Common utilities + */ + +#include "utils.h" +#include + +namespace tvm { +namespace ansor { + +NullStream& NullStream::Global() { + static NullStream stream; + return stream; +} + +const std::vector >& SplitFactorizationMemo::GetFactorizationSchemes( + int extent, int n_lengths, int max_innermost_factor) { + QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); + auto it = memory_.find(key); + if (it != memory_.end()) { + return it->second; + } + + tmp_stack_.assign(n_lengths, PrimExpr()); + results_ = &memory_[key]; + n_lengths_ = n_lengths; + + DfsEnumerate(0, extent, max_innermost_factor); + + return *results_; +} + +void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor) { + if (now == n_lengths_) { + if (tmp_stack_.back().as()->value <= max_innermost_factor) { + results_->push_back(tmp_stack_); + } + } else { + for (const auto& f : GetFactors(remaining_lenght)) { + tmp_stack_[now] = PrimExpr(f); + DfsEnumerate(now + 1, 
remaining_lenght / f, max_innermost_factor); + } + } +} + +const std::vector& SplitFactorizationMemo::GetFactors(int n) { + auto it = factor_memory_.find(n); + if (it != factor_memory_.end()) { + return it->second; + } + + std::vector& res = factor_memory_[n]; + int step = n % 2 == 0 ? 1 : 2; + for (size_t i = 1; i < static_cast(std::sqrt(n)) + 1; i += step) { + if (n % i == 0) { + res.push_back(i); + if (n / i != i) { + res.push_back(n/i); + } + } + } + std::sort(res.begin(), res.end()); + return res; +} + +ThreadPool& ThreadPool::Global() { + static ThreadPool* pool = new ThreadPool(); + static int ct = 0; + + ct = (ct + 1) % ThreadPool::REFRESH_EVERY; + + if (ct == 0) { + pool->Abort(); + delete pool; + pool = new ThreadPool(); + } + + if (pool->NumWorkers() == 0) { + pool->Launch(std::thread::hardware_concurrency()); + } + + return *pool; +} + +TVM_REGISTER_GLOBAL("ansor.utils.GetFactorizationSchemes") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int extent = args[0]; + int n_lengths = args[1]; + int max_innermost_factor = args[2]; + SplitFactorizationMemo memo; + + Array > result; + for (const auto& lens : memo.GetFactorizationSchemes(extent, n_lengths, max_innermost_factor)) { + result.push_back(lens); + } + + *ret = result; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/utils.h b/src/ansor/utils.h new file mode 100644 index 000000000000..4e98bb907af9 --- /dev/null +++ b/src/ansor/utils.h @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/utils.h + * \brief Common utilities + */ + +#ifndef TVM_ANSOR_UTILS_H_ +#define TVM_ANSOR_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace std { + +/*! \brief Hash function for std::pair */ +template +struct hash > { + std::size_t operator()(const std::pair& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } +}; + +/*! \brief Hash function for std::tuple */ +template +struct hash > { + std::size_t operator()(const std::tuple& k) const { + return ::dmlc::HashCombine( + ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))), + std::hash()(std::get<2>(k))); + } +}; + +/*! \brief Hash function for std::vector */ +template +struct hash > { + std::size_t operator()(const std::vector& vec) const { + if (vec.empty()) { + return 0; + } + std::size_t ret = std::hash()(vec[0]); + for (size_t i = 1; i < vec.size(); ++i) { + ret = ::dmlc::HashCombine(ret, std::hash()(vec[i])); + } + return ret; + } +}; + +} // namespace std + +namespace tvm { +namespace ansor { + +/*! \brief Macro to make it easy to define mutable object ref type given node */ +#define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ + class TypeName : public ObjectRef { \ + public: \ + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ + }; \ + +/********** Utilities for std::vector, std::set, std::string **********/ +/*! 
\brief Get the first appearance index of elements in a vector */ +template +inline void GetIndices(const std::vector& array, + const std::vector& to_locate, + std::vector* indices) { + for (const auto& v : to_locate) { + auto it = std::find(array.begin(), array.end(), v); + if (it != array.end()) { + indices->push_back(it - array.begin()); + } else { + LOG(FATAL) << "Cannot find the item"; + } + } +} + +/*! \brief Get the first appearance index of an element in a vector */ +template +inline int GetIndex(const std::vector& array, const T& to_locate) { + for (size_t i = 0; i < array.size(); ++i) { + if (array[i] == to_locate) { + return i; + } + } + LOG(FATAL) << "Cannot find the item"; + return -1; +} + +/*! \brief Delete an element in a vector */ +template +inline void DeleteItem(std::vector* array, const T& to_delete) { + auto iter = std::find(array->begin(), array->end(), to_delete); + if (iter != array->end()) { + array->erase(iter); + } +} + +/*! \brief Compute the product of all elements in a vector */ +inline int64_t ElementProduct(const std::vector& array) { + int64_t ret = 1; + for (auto x : array) { + ret *= x; + } + return ret; +} + +/*! \brief Get the maximum element in a vector */ +template +T MaximumElement(const std::vector& array) { + CHECK(!array.empty()); + const T* pmax = &array[0]; + for (size_t i = 1; i < array.size(); ++i) { + if (array[i] > *pmax) { + pmax = &array[i]; + } + } + return *pmax; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* in) { + out->insert(out->end(), std::make_move_iterator(in->begin()), + std::make_move_iterator(in->end())); + return *out; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args... args) { + ConcatenateMove(out, first); + ConcatenateMove(out, args...); + return *out; +} + +/*! 
\brief Get a random permutation of integers [0, n-1] */ +template +void RandomPermutation(int n, std::vector* out, G* gen) { + out->assign(n, 0); + std::iota(out->begin(), out->end(), 0); + std::shuffle(out->begin(), out->end(), *gen); +} + +/*! \brief Random sample without replacement */ +template +void RandomSample(std::vector* in_data, size_t out_size, G* gen) { + // Note: This function is inefficient in the cases when out_size << in_data.size() + out_size = std::min(in_data->size(), out_size); + + if (in_data->size() <= out_size) { // return all + return; + } + std::vector indices; + RandomPermutation(in_data->size(), &indices, gen); + + std::vector tmp_data; + tmp_data.reserve(out_size); + for (size_t i = 0; i < out_size; ++i) { + tmp_data.push_back(std::move((*in_data)[indices[i]])); + } + + *in_data = std::move(tmp_data); +} + +/*! \brief Argsort. Order: largest to smallest */ +template +inline void Argsort(const std::vector& scores, std::vector* index) { + index->clear(); index->reserve(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + index->push_back(i); + } + auto cmp = [&scores](int l, int r) { + return scores[l] > scores[r]; + }; + std::sort(index->begin(), index->end(), cmp); +} + +/*! \brief Return whether a string ends with another substring */ +inline bool StrEndsWith(const std::string& a, const std::string& b) { + if (b.size() > a.size()) return false; + return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin()); +} + +/*! \brief Return whether a string starts with another substring */ +inline bool StrStartsWith(const std::string& a, const std::string& b) { + if (b.size() > a.size()) return false; + return std::equal(a.begin(), a.begin() + b.size(), b.begin()); +} + +/*! 
\brief Replace a sub-string to another sub-string in a string */ +inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { + auto pos = base->find(from); + while (pos != std::string::npos) { + base->replace(pos, from.size(), to); + pos = base->find(from, pos + to.size()); + } +} + +/********** Utilities for TVM Containers / ByteArray **********/ +/*! \brief Compute mean of a FloatImm array */ +inline double FloatArrayMean(const Array& float_array) { + double sum = 0; + if (float_array.empty()) { + return 0.0; + } + + for (const auto&x : float_array) { + auto floatimm = x.as(); + CHECK(floatimm != nullptr); + sum += floatimm->value; + } + return sum / float_array.size(); +} + +/********** Other Utilities **********/ +/*! \brief Get an int value from an Expr */ +inline int64_t GetIntImm(const PrimExpr& expr) { + auto pint = expr.as(); + CHECK(pint != nullptr); + return pint->value; +} + +/*! \brief Compute the product of the lengths of axes */ +inline int64_t AxisLengthProd(const Array& axes) { + int64_t ret = 1.0; + for (const auto& x : axes) { + if (const IntImmNode* imm = x->dom->extent.as()) { + ret *= imm->value; + } else { + return -1.0; + } + } + return ret; +} + +/*! \brief An empty output stream */ +class NullStream : public std::ostream { + public: + NullStream() : std::ostream(nullptr) {} + NullStream(const NullStream &) : std::ostream(nullptr) {} + static NullStream& Global(); +}; + +template +NullStream& operator<<(NullStream& os, const T& value) { + return os; +} + +/*! \brief Get std cout with verbose control */ +inline std::ostream& StdCout(int verbose) { + if (verbose >= 1) { + return std::cout; + } else { + return NullStream::Global(); + } +} + +/*! 
\brief Print a title */ +inline void PrintTitle(const std::string& title, int verbose) { + if (verbose >= 1) { + std::cout << "------------------------------------------------------------" << "\n"; + std::cout << "----------------------- [ " << title << " ]\n"; + std::cout << "------------------------------------------------------------" << std::endl; + } +} + +/*! \brief A simple thread pool */ +class ThreadPool { + public: + void Launch(size_t n = 1) { + for (std::size_t i = 0; i < n; ++i) { + threads_.emplace_back([this] {WorkerFunc();}); + } + } + + void BeginBatch(int n) { + finish_ct_ = n; + is_finished_ = n <= 0; + } + + template::type> + std::future Enqueue(F&& f, Args&&... args) { + std::packaged_task p(std::bind(f, args...)); + + auto r = p.get_future(); + { + std::unique_lock l(m_); + work_.emplace_back(std::move(p)); + } + work_signal_.notify_one(); + return r; + } + + void WaitBatch() { + std::unique_lock l(finish_mutex_); + if (!is_finished_) { + finish_signal_.wait(l); + } + } + + void Abort() { + CancelPending(); + Join(); + } + + void CancelPending() { + std::unique_lock l(m_); + work_.clear(); + } + + void Join() { + { + std::unique_lock l(m_); + for (size_t i = 0; i < threads_.size(); ++i) { + work_.push_back({}); + } + } + work_signal_.notify_all(); + for (auto& t : threads_) { + t.join(); + } + threads_.clear(); + } + + size_t NumWorkers() { + return threads_.size(); + } + + static const int REFRESH_EVERY = 128; + static ThreadPool& Global(); + + ~ThreadPool() { + Join(); + } + + private: + void WorkerFunc() { + while (true) { + std::packaged_task f; + { + std::unique_lock l(m_); + if (work_.empty()) { + work_signal_.wait(l, [&]{ return !work_.empty(); }); + } + f = std::move(work_.front()); + work_.pop_front(); + } + if (!f.valid()) { return; } + f(); + + finish_ct_--; + if (finish_ct_ == 0) { + std::unique_lock l(finish_mutex_); + + is_finished_ = true; + finish_signal_.notify_one(); + } + } + } + + std::mutex m_; + std::condition_variable 
work_signal_; + std::deque> work_; + std::vector threads_; + + bool is_finished_; + std::mutex finish_mutex_; + std::atomic finish_ct_; + std::condition_variable finish_signal_; +}; + +/*! + * \brief Enumerate all possible factorization schemes for splitting an axes. + * \note This class will memorize the results for reuse. + */ +class SplitFactorizationMemo { + public: + using QueryKey = std::tuple; + + const std::vector >& GetFactorizationSchemes( + int extent, int n_lengths, int max_innermost_factor); + const std::vector& GetFactors(int n); + + private: + void DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor); + + std::unordered_map > > memory_; + + int n_lengths_; + std::vector tmp_stack_; + std::vector >* results_; + std::unordered_map> factor_memory_; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_UTILS_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc new file mode 100644 index 000000000000..36ac46f49551 --- /dev/null +++ b/tests/cpp/ansor_test.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +// todo(jcf94): do not use relative path +#include "../../src/ansor/loop_state.h" + +// Compute declaration for test +tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, + int CI, int CO, + int kernel_size, + int strides, int padding, + int dilation = 1) { + using namespace tvm; + using namespace tvm::te; + + Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); + Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, + DataType::Float(32), "Kernel"); + Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias"); + Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); + Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); + + int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + + const auto& conv = + topi::conv2d_nchw(data, kernel, padding, padding, strides, strides); + CHECK(conv->shape[2].as()->value == OH); + CHECK(conv->shape[3].as()->value == OW); + + const auto& bias_add = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return conv[i][j][k][l] + bias[j][0][0]; + }, + "Bias_add"); + const auto& bn_mul = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return bias_add[i][j][k][l] * bn_scale[j][0][0]; + }, + "Bn_mul"); + const auto& bn_add = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return bn_mul[i][j][k][l] + bn_offset[j][0][0]; + }, + "Bn_add"); + const auto& out = topi::relu(bn_add); + + return {data, kernel, bias, bn_scale, bn_offset, out}; +} + +using namespace tvm::ansor; + +// Test Access Analyzer +TEST(ComputeDAG, GetProducersConsumers) { + const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::ansor::ComputeDAG(tensors); + int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; 
+ int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; + + State s0 = dag.GetInitState(); + std::unordered_set set; + { + std::vector> consumer_list = { + {data, padding}, {padding, conv}, {kernel, conv}, + {conv, bias_add}, {bias, bias_add}, {bias_add, bn_mul}, + {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, + {bn_add, relu}}; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } + + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_inline(padding); + { + std::vector> consumer_list = { + {data, conv}, {kernel, conv}, {conv, relu}}; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + 
testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py new file mode 100644 index 000000000000..62ebeb99a6c8 --- /dev/null +++ b/tests/python/unittest/test_ansor_common.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""Common functions for ansor test cases"""

from tvm import te, ansor
import topi


@ansor.register_workload_func
def matmul_ansor_test(N, M, K):
    """Matmul workload: C[N, M] = sum_k A[N, K] * B[K, M]."""
    a_mat = te.placeholder((N, K), name='A')
    b_mat = te.placeholder((K, M), name='B')
    k = te.reduce_axis((0, K), name='k')
    c_mat = te.compute((N, M),
                       lambda r, c: te.sum(a_mat[r][k] * b_mat[k][c], axis=[k]),
                       name='C')
    return [a_mat, b_mat, c_mat]


def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
    """Conv2d (NCHW) followed by bias add, batch-norm scale/offset and ReLU."""
    data = te.placeholder((N, CI, H, W), name='Data')
    kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
    bias = te.placeholder((CO, 1, 1), name='Bias')
    bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale')
    bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset')

    # Standard convolution output-size formula.
    OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1
    OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1

    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
    biased = te.compute((N, CO, OH, OW),
                        lambda n, co, h, w: conv[n, co, h, w] + bias[co, 0, 0],
                        name='Bias_add')
    scaled = te.compute((N, CO, OH, OW),
                        lambda n, co, h, w: biased[n, co, h, w] * bn_scale[co, 0, 0],
                        name='Bn_mul')
    shifted = te.compute((N, CO, OH, OW),
                         lambda n, co, h, w: scaled[n, co, h, w] + bn_offset[co, 0, 0],
                         name='Bn_add')
    out = topi.nn.relu(shifted)

    # NOTE(review): bn_offset is returned before bn_scale, unlike the C++ twin
    # of this workload -- kept as-is since callers index positionally.
    return [data, kernel, bias, bn_offset, bn_scale, out]


def get_tiled_matmul():
    """Return (dag, state) with the 512^3 matmul tiled 4x8x8 (space) / 8x4x4 (reduce)."""
    A, B, C = matmul_ansor_test(512, 512, 512)
    dag = ansor.ComputeDAG([A, B, C])

    s0 = dag.get_init_state()
    space_its = s0.split(C, s0[C].iters[0], [4, 8, 8])
    reduce_its = s0.split(C, s0[C].iters[4], [8, 4, 4])
    s0.reorder(C, [space_its[0], reduce_its[0], space_its[1], reduce_its[1],
                   space_its[2], reduce_its[2], space_its[3], reduce_its[3],
                   s0[C].iters[8]])

    return dag, s0
"""Test ComputeDAG (replay, infer bound)"""

import tvm
from tvm import ansor, te

from test_ansor_common import get_tiled_matmul


def test_apply_steps():
    """Replaying the recorded transform steps should lower without error."""
    dag, state = get_tiled_matmul()
    dag.print_python_code_from_state(state)
    sch, tensors = dag.apply_steps_from_state(state)
    tvm.lower(sch, tensors, simple_mode=True)


def test_infer_bound():
    """Bound inference should succeed on a tiled state."""
    dag, state = get_tiled_matmul()
    dag.infer_bound_from_state(state)


def test_estimate_flop():
    """A 512^3 matmul performs 2 * 512^3 FLOPs (one multiply + one add per MAC)."""
    dag, _ = get_tiled_matmul()

    assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5


if __name__ == "__main__":
    test_apply_steps()
    test_infer_bound()
    test_estimate_flop()
"""Test loop state and schedule primitives"""

import numpy as np

import tvm
from tvm import ansor, te
import topi

from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu


def test_split_fuse_reorder():
    """split/fuse/reorder should update iterator order and extents as expected."""
    A, B, C = matmul_ansor_test(512, 512, 512)
    dag = ansor.ComputeDAG([A, B, C])
    state = dag.get_init_state()
    i, j, k = state[C].iters

    assert i.range.extent == 512

    # 512 split by inner length 16 -> outer 32, inner 16
    io, ii = state.split(C, i, [16])
    assert state[C].iters[0] == io
    assert state[C].iters[1] == ii
    assert io.range.extent == 32
    assert ii.range.extent == 16

    jo, ji = state.split(C, j, [8])
    assert jo.range.extent == 64
    assert ji.range.extent == 8

    state.reorder(C, [io, jo, k, ji, ii])
    assert state[C].iters[2].range.extent == 512

    fused = state.fuse(C, [io, jo])
    assert fused.range.extent == 2048

    # inner_to_outer=False applies the split factors from the outside in.
    other = dag.get_init_state()
    i, j, _ = other[C].iters
    other.split(C, i, [8, 2])
    other.split(C, j, [32, 8], False)
    for pos, extent in enumerate([32, 8, 2, 32, 8, 2]):
        assert other[C].iters[pos].range.extent == extent

if __name__ == "__main__":
    test_split_fuse_reorder()
"""Test measurement and log serialization"""

import tempfile

import tvm
from tvm import ansor

from test_ansor_common import get_tiled_matmul


def test_serialization():
    """Measure records written to a log file should round-trip through LogReader."""
    dag, state = get_tiled_matmul()
    target = tvm.target.create("llvm")
    task = ansor.SearchTask(dag, "test", target)

    inp = ansor.measure.MeasureInput(task, state)
    res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        ansor.serialization.write_measure_records_to_file(fp.name, [inp], [res])

        reader = ansor.serialization.LogReader(fp.name)
        inputs, _ = reader.read_lines()
        assert len(inputs) == 1

        # The deserialized state must describe the same schedule as the original.
        bound_orig = dag.infer_bound_from_state(state)
        bound_loaded = dag.infer_bound_from_state(inputs[0].state)

        assert bound_orig == bound_loaded
        assert not (bound_orig == dag.get_init_state())


def test_measure_local_builder_runner():
    """LocalBuilder/LocalRunner should build and run a tiled matmul without errors."""
    dag, s0 = get_tiled_matmul()

    tgt = tvm.target.create("llvm")
    task = ansor.SearchTask(dag, "test", tgt)

    minp = ansor.MeasureInput(task, s0)
    build_results = ansor.LocalBuilder().build([minp])
    assert build_results[0].error_no == 0

    measure_results = ansor.LocalRunner().run([minp], build_results)
    assert measure_results[0].error_no == 0


if __name__ == "__main__":
    test_serialization()
    test_measure_local_builder_runner()
"""Test search policy"""

import random
import numpy as np
import tempfile
import threading

import tvm
from tvm import ansor

from test_ansor_common import matmul_ansor_test

def search_common(target="llvm", seed=None, runner='local',
                  cost_model=None, n_trials=2, params=None,
                  pre_search_callbacks=None):
    """Run a schedule search on a 128^3 matmul and verify the result numerically.

    seed : int or None. The original default `seed=random.randint(1, 1 << 30)`
    was evaluated once at import time (defaults bind at function definition),
    so every call in a process silently shared one "random" seed; drawing it
    per call when None restores the intent while staying call-compatible.
    """
    if seed is None:
        seed = random.randint(1, 1 << 30)
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = ansor.make_workload_key_func(matmul_ansor_test, (N, N, N))
    dag = ansor.workload_key_to_dag(workload_key)
    target = tvm.target.create(target)
    task = ansor.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = ansor.EmptyPolicy()
        # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed)
        tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner,
                                       measure_callbacks=[ansor.LogToFile(log_file)],
                                       pre_search_callbacks=pre_search_callbacks)
        sch, args = ansor.auto_schedule(task, search_policy=search_policy,
                                        tune_option=tune_option)
        inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            # Numerical check: the generated matmul must match numpy.
            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(), np.dot(
                a.asnumpy(), b.asnumpy()), rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as e:
            # Chain the cause so the real failure is not masked by the re-raise.
            raise Exception("Error encountered with seed: %d" % (seed)) from e
        print()


def test_search_basic():
    # wrap the search in a new thread to avoid the conflict
    # between python's multiprocessing and tvm's thread pool
    t = threading.Thread(target=search_common, kwargs={'seed': 944563397})
    t.start()
    t.join()

if __name__ == "__main__":
    test_search_basic()