diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py
index 7446578a38d7..1776666f4ed5 100644
--- a/python/tvm/meta_schedule/profiler.py
+++ b/python/tvm/meta_schedule/profiler.py
@@ -34,7 +34,7 @@ def __init__(self) -> None:
         )
 
     def get(self) -> Dict[str, float]:
-        """Get the profiling results in minutes"""
+        """Get the profiling results in seconds"""
         return _ffi_api.ProfilerGet(self)  # type: ignore # pylint: disable=no-member
 
     def table(self) -> str:
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 1cfd4ab833be..7129546dd8b7 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,7 +17,7 @@
 """Customized builder and runner methods"""
 # pylint: disable=import-outside-toplevel
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable
 
 if TYPE_CHECKING:
     import numpy as np  # type: ignore
 
@@ -143,7 +143,7 @@ def run_module_via_rpc(
     rpc_config: "RPCConfig",
     lib: Union["Module", "Executable"],
     dev_type: str,
-    args: Dict[str, "np.ndarray"],
+    args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
     continuation: Callable,
     backend: Optional[str] = "graph",
 ):
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py
index fe0984d51c50..17064c64ab52 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -86,7 +86,7 @@ def create_timer(backend: str) -> Callable:
 
     def f_timer(
         rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
-        dev: tvm.device,
+        dev: tvm.runtime.Device,
         input_data: Dict[str, NDArray],
     ) -> None:
         """Run and benchmark the given runtime module, print out the result.
@@ -95,7 +95,7 @@ def f_timer(
         ----------
         rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
             The runtime module or vm executable.
-        dev : tvm.device
+        dev : tvm.runtime.Device
             The device type to run workload.
         input_data : Dict[str, np.ndarray]
             The input data as a dictionary.
@@ -152,7 +152,7 @@ def create_time_per_layer(graph: str) -> Callable:
 
     def f_time_per_layer(
         rt_mod: tvm.runtime.Module,
-        dev: tvm.device,
+        dev: tvm.runtime.Device,
         input_data: Dict[str, NDArray],
     ) -> None:
         """Run and benchmark the per-layer performance of given runtime module,
@@ -162,7 +162,7 @@ def f_time_per_layer(
         ----------
         rt_mod : tvm.runtime.Module
            The runtime module.
-        dev : tvm.device
+        dev : tvm.runtime.Device
             The device type to run workload.
         input_data : Dict[str, np.ndarray]
             The input data as a dictionary.
@@ -192,3 +192,51 @@ def f_time_per_layer(
         )
 
     return f_time_per_layer
+
+
+def create_calculator(backend: str) -> Callable:
+    """Create a function to fetch the computing result of running the given runtime module.
+
+    Parameters
+    ----------
+    backend : str
+        The backend to use; only "tir" is supported for now.
+
+    Returns
+    -------
+    func : Callable
+        The function to fetch the computing result.
+    """
+
+    def f_calculator(
+        rt_mod: tvm.runtime.Module,
+        dev: tvm.runtime.Device,  # pylint: disable=unused-argument
+        input_data: Dict[str, NDArray],
+    ) -> List[NDArray]:
+        """Fetch the result of running the given runtime module.
+
+        Parameters
+        ----------
+        rt_mod : tvm.runtime.Module
+            The runtime module.
+        dev : tvm.runtime.Device
+            The device type to run workload.
+        input_data : Dict[str, np.ndarray]
+            The input data as a dictionary.
+        """
+        try:
+            if backend == "tir":
+                # Sort by key so integer-keyed inputs keep their positional argument order
+                data = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])]
+                rt_mod(*data)
+                return data
+            else:
+                raise ValueError(f"Backend {backend} not supported in f_calculator!")
+
+        except Exception as exc:  # pylint: disable=broad-except
+            print(
+                f"Run module f_calculator via RPC failed, exception: {exc}",
+            )
+            return None
+
+    return f_calculator
diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py
new file mode 100644
index 000000000000..5e48bfb6b04e
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/validate_database.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""JSON Database validation script"""
+from typing import Union, Callable, List
+from distutils.util import strtobool
+import argparse
+import logging
+import warnings
+import numpy as np  # type: ignore
+
+import tvm
+from tvm.target import Target
+from tvm.ir import IRModule
+from tvm.tir import Schedule
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data
+from tvm._ffi import get_global_func, register_func
+from tvm.support import describe
+
+DELIMITER = "\n" + "-" * 30 + "\n"
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+        help="The path to the work directory containing database files.",
+    )
+    args.add_argument(
+        "--target",
+        type=Target,
+        required=True,
+    )
+    args.add_argument(
+        "--baseline-target",
+        type=Target,
+        default="llvm -num-cores=1",
+        required=False,
+        help="The baseline target to compile the original module.",
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--number",
+        type=int,
+        default=3,
+    )
+    args.add_argument(
+        "--repeat",
+        type=int,
+        default=1,
+    )
+    args.add_argument(
+        "--min-repeat-ms",
+        type=int,
+        default=100,
+    )
+    args.add_argument(
+        "--cpu-flush",
+        type=lambda x: bool(strtobool(x)),
+        help="example: True / False",
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    parsed.rpc_config = ms.runner.RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=600,
+    )
+    if parsed.cpu_flush and parsed.target.kind.name != "llvm":
+        warnings.warn("cpu_flush is only supported on llvm target")
+    return parsed
+
+
+# logging
+logging.basicConfig(
+    format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+# arg parser
+ARGS = _parse_args()
+
+
+@register_func("tvm.meta_schedule.testing.default_input_generator")
+def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
+    args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"])
+    inputs = [
+        tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype))
+        for arg_info in args_info
+    ]
+    return inputs
+
+
+@register_func("tvm.meta_schedule.testing.default_check_metric")
+def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
+    assert len(a) == len(b), "Different number of outputs from two modules"
+    for i, _ in enumerate(a):
+        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
+            return False
+    return True
+
+
+def validate_correctness(
+    original_mod: IRModule,  # compiled for "baseline_target"
+    scheduled_mod: IRModule,  # compiled for "target"
+    *,
+    baseline_target: Target,
+    target: Target,
+    dev_type: str,
+    rpc_config: ms.runner.RPCConfig,
+    f_input_generator: Union[
+        str, Callable[[IRModule], List[tvm.nd.NDArray]]
+    ] = default_input_generator,
+    f_check_metric: Union[
+        str, Callable[[List[tvm.nd.NDArray], List[tvm.nd.NDArray]], bool]
+    ] = default_check_metric,
+) -> bool:
+    """Function to validate the correctness of a scheduled module.
+
+    Parameters
+    ----------
+    original_mod : IRModule
+        The original module to be compiled.
+    scheduled_mod : IRModule
+        The scheduled module to be compiled.
+    baseline_target : Target
+        The baseline target to compile the original module.
+    target : Target
+        The target to compile the scheduled module.
+    dev_type : str
+        The device type to run the module via rpc.
+    rpc_config : RPCConfig
+        The RPCConfig to run the scheduled module.
+    f_input_generator : Union[str, Callable]
+        The function to generate the input data.
+    f_check_metric : Union[str, Callable]
+        The function to check the metric.
+
+    Returns
+    -------
+    result : bool
+        The result of the validation.
+ """ + + def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: + """Convert a list of TVM NDArray to a list of numpy array""" + assert a is not None, "Empty result cannot be converted to numpy" + return [x.numpy() for x in a] + + def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: + """Convert a list of numpy array to a list of TVM NDArray""" + assert a is not None, "Empty result cannot be converted to TVM NDArray" + return [tvm.nd.array(x) for x in a] + + def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: + """Build and run the module on the target device.""" + rt_mod = tvm.build(mod, target=target) + return run_module_via_rpc( + rpc_config=rpc_config, + lib=rt_mod, + dev_type=dev_type, + args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension + continuation=create_calculator(backend="tir"), + backend="tir", + ) + + # fetch functions & prepare inputs + if isinstance(f_input_generator, str): + f_input_generator = get_global_func(f_input_generator) + if isinstance(f_check_metric, str): + f_check_metric = get_global_func(f_check_metric) + inputs = to_numpy(f_input_generator(original_mod)) # type: ignore + # build & run original result + original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu")) + scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type)) + # check metric + if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore + return True + else: + print( + ("\n\n").join( + [ + "Validation failed!", + "Original Result:" + DELIMITOR + str(original_res), + "Scheduled Result:" + DELIMITOR + str(scheduled_res), + "Input:" + DELIMITOR + str(inputs), + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + ] + ) + ) + return False + + +def main(): + """Main function""" + describe() + database = ms.database.create(work_dir=ARGS.work_dir) + target = ARGS.target + if target.kind.name == "llvm": + dev_type = "cpu" + elif target.kind.name == "cuda": + dev_type = "cuda" + else: + raise RuntimeError(f"Unsupported target kind: {target.kind.name}") + records = database.get_all_tuning_records() + with ms.Profiler() as profiler: + for i, record in enumerate(records): + scope_name = f"validate #{i}" + with profiler.timeit(scope_name): + original_mod = record.workload.mod + sch = Schedule(original_mod) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + is_success = False + try: + is_success = validate_correctness( + original_mod=original_mod, + scheduled_mod=scheduled_mod, + target=target, + baseline_target=ARGS.baseline_target, + dev_type=dev_type, + rpc_config=ARGS.rpc_config, + ) + except Exception as e: # pylint: disable=broad-except, invalid-name + print( + ("\n\n").join( + [ + "Validation failed!", + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + "Exception" + DELIMITOR + str(e), + ] + ) + ) + if is_success: + print( + f"Progress {i+1: 6d} / {len(records): 6d} checked," + f" used {float(profiler.get()[scope_name]): 3.3f} sec." + ) + else: + return + + print("Validation passed!") + print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") + + +if __name__ == "__main__": + main()