From a234ac42a4aa67b47cfca6c1b886bccf978dc975 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Sep 2022 14:52:55 -0700 Subject: [PATCH 01/12] Add validation scripts. --- .../tvm/meta_schedule/testing/tune_utils.py | 47 ++++ .../testing/validate_database.py | 256 ++++++++++++++++++ 2 files changed, 303 insertions(+) create mode 100644 python/tvm/meta_schedule/testing/validate_database.py diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index fe0984d51c50..ba6e9c3de6d0 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -192,3 +192,50 @@ def f_time_per_layer( ) return f_time_per_layer + + +def create_computer(backend: str) -> Callable: + """Create a function to fetch the computing result of running the given runtime module. + + Parameters + ---------- + backend : str + The backend to use, graph / vm. + + Returns + ------- + func : Callable + The function to fetch the computing result. + """ + + def f_computer( + rt_mod: tvm.runtime.Module, + dev: tvm.device, + input_data: Dict[str, NDArray], + ) -> None: + """Fetch the result of running the given runtime module. + + Parameters + ---------- + rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable] + The runtime module or vm executable. + dev : tvm.device + The device type to run workload. + input_data : Dict[str, np.ndarray] + The input data as a dictionary. + """ + try: + if backend == "tir": + + inputs = [lambda x: x[1] for x in sorted(lambda x: x[0], input_data.items())] + rt_mod["default"](dev)(inputs) + return [x.cpu().numpy() for x in inputs] + else: + raise ValueError(f"Backend {backend} not supported in f_timer!") + + except Exception as exc: # pylint: disable=broad-except + print( + f"Run module f_computer via RPC failed, exception: {exc}", + ) + + return f_computer diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py new file mode 100644 index 000000000000..6078a87e0482 --- /dev/null +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -0,0 +1,256 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""JSON Database validation script""" +from typing import Union, Callable, List +from distutils.util import strtobool +import argparse +import logging +import warnings +import numpy as np + +import tvm +from tvm.target import Target +from tvm.ir import IRModule +from tvm.tir import Schedule +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.meta_schedule.testing.tune_utils import create_computer, generate_input_data +from tvm._ffi import get_global_func, register_func +from tvm.support import describe + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--path-workload", + type=str, + required=True, + help="The path to the database workload file.", + ) + args.add_argument( + "--path-tuning-record", + type=str, + required=True, + help="The path to the database tuning record file.", + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--baseline-target", + type=str, + default="llvm -num-cores=1", + required=False, + help="The baseline target to compile the original module.", + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--number", + type=int, + default=3, + ) + args.add_argument( + "--repeat", + type=int, + default=1, + ) + args.add_argument( + "--min-repeat-ms", + type=int, + default=100, + ) + args.add_argument( + "--cpu-flush", + type=lambda x: bool(strtobool(x)), + help="example: True / False", + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=600, + ) + if parsed.cpu_flush and parsed.target.kind.name != "llvm": + warnings.warn("cpu_flush is only supported on llvm target") + return parsed + + +# logging +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) + +# arg parser +ARGS = _parse_args() + + +@register_func("tvm.meta_schedule.testing.default_input_generator") +def default_input_generator(mod: IRModule) -> List[np.ndarray]: + args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) + inputs = [ + generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype) + for arg_info in args_info + ] + return inputs + + +@register_func("tvm.meta_schedule.testing.default_check_metric") +def default_check_metric(a: np.ndarray, b: np.ndarray) -> bool: + return np.allclose(a, b, rtol=1e-3, atol=2e-3) + + +def validate_correctness( + original_mod: IRModule, # compiled for "baseline_target" + scheduled_mod: IRModule, # compiled for "target" + *, + baseline_target: Union[str, Target], + target: Union[str, Target], + dev_type: str, + rpc_config: ms.runner.RPCConfig, + f_input_generator: Union[str, Callable] = "tvm.meta_schedule.testing.default_input_generator", + f_check_metric: Union[str, Callable] = "tvm.meta_schedule.testing.default_check_metric", +) -> bool: + """Function to validate the correctness of a scheduled module. + + Parameters + ---------- + original_mod : IRModule + The original module to be compiled. + scheduled_mod : IRModule + The scheduled module to be compiled. + target : Target + The target to compile the scheduled module. + rpc_config : RPCConfig + The RPCConfig to run the scheduled module. + f_input_generator : Union[str, Callable] + The function to generate the input data. + f_check_metric : Union[str, Callable] + The function to check the metric. + + Returns + ------- + result : ... + The result of the validation. + """ + + def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: + """Build and run the module on the target device.""" + rt_mod = tvm.build(mod, target=target) + args = {i: arg for i, arg in enumerate(inputs)} + return run_module_via_rpc( + rpc_config=rpc_config, + lib=rt_mod, + dev_type=dev_type, + args=args, + continuation=create_computer(backend="tir"), + backend="tir", + ) + + # make targets + target = Target(target) + baseline_target = Target(baseline_target) + # fetch functions & prepare inputs + if isinstance(f_input_generator, str): + input_generator_func = get_global_func(f_input_generator) + if isinstance(f_check_metric, str): + check_metric_func = get_global_func(f_check_metric) + inputs = input_generator_func(original_mod) + # build & run original result + original_res = build_and_run(original_mod, target=baseline_target, dev_type="cpu") + scheduled_res = build_and_run(scheduled_mod, target=target, dev_type=dev_type) + # check metric + if not check_metric_func(original_res, scheduled_res): + return True + else: + print( + ("\n\n").join( + [ + "Validation failed!", + "Original Result:\n" + "-" * 10 + str(original_res), + "Scheduled Result:\n" + "-" * 10 + str(scheduled_res), + "Input:\n" + "-" * 10 + str(inputs), + "Original IRModule:\n" + "-" * 10 + original_mod.script(), + "Scheduled IRModule:\n" + "-" * 10 + scheduled_mod.script(), + ] + ) + ) + return False + + +def main(): + """Main function""" + describe() + database = ms.database.JSONDatabase( + path_workload=ARGS.path_workload, path_tuning_record=ARGS.path_tuning_record + ) + assert Target(ARGS.target).kind.name in ["llvm", "cuda"] + dev_type = "cpu" if Target(ARGS.target).kind.name == "llvm" else "cuda" + records = database.get_all_tuning_records() + for i, record in enumerate(records): + original_mod = record.workload.mod + sch = Schedule(original_mod) + scheduled_mod = record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + try: + flag = validate_correctness( + original_mod=original_mod, + scheduled_mod=scheduled_mod, + target=ARGS.target, + baseline_target=ARGS.baseline_target, + dev_type=dev_type, + rpc_config=ARGS.rpc_config, + ) + except Exception as e: # pylint: disable=broad-except, invalid-name + print( + ("\n\n").join( + [ + "Validation failed!", + "Original IRModule:\n" + "-" * 10 + original_mod.script(), + "Scheduled IRModule:\n" + "-" * 10 + scheduled_mod.script(), + "Exception\n" + "-" * 10 + str(e), + ] + ) + ) + if flag: + print(f"Progress {i+1: 6d} / {len(records): 6d} checked.") + else: + return + + print("Validation passed!") + + +if __name__ == "__main__": + main() From 22d955b8049d8de47f27a3c081060aaa3db5140f Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Sep 2022 15:12:18 -0700 Subject: [PATCH 02/12] Fix testing script. --- .../tvm/meta_schedule/testing/tune_utils.py | 19 +++--- .../testing/validate_database.py | 64 +++++++++++++------ 2 files changed, 52 insertions(+), 31 deletions(-) diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index ba6e9c3de6d0..a42d1021e454 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -86,7 +86,7 @@ def create_timer(backend: str) -> Callable: def f_timer( rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable], - dev: tvm.device, + dev: tvm.runtime.Device, input_data: Dict[str, NDArray], ) -> None: """Run and benchmark the given runtime module, print out the result. @@ -95,7 +95,7 @@ def f_timer( ---------- rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable] The runtime module or vm executable. - dev : tvm.device + dev : tvm.runtime.Device The device type to run workload. input_data : Dict[str, np.ndarray] The input data as a dictionary. @@ -152,7 +152,7 @@ def create_time_per_layer(graph: str) -> Callable: def f_time_per_layer( rt_mod: tvm.runtime.Module, - dev: tvm.device, + dev: tvm.runtime.Device, input_data: Dict[str, NDArray], ) -> None: """Run and benchmark the per-layer performance of given runtime module, @@ -162,7 +162,7 @@ def f_time_per_layer( ---------- rt_mod : tvm.runtime.Module The runtime module. - dev : tvm.device + dev : tvm.runtime.Device The device type to run workload. input_data : Dict[str, np.ndarray] The input data as a dictionary. @@ -210,7 +210,7 @@ def create_computer(backend: str) -> Callable: def f_computer( rt_mod: tvm.runtime.Module, - dev: tvm.device, + dev: tvm.runtime.Device, input_data: Dict[str, NDArray], ) -> None: """Fetch the result of running the given runtime module. @@ -226,12 +226,11 @@ def f_computer( """ try: if backend == "tir": - - inputs = [lambda x: x[1] for x in sorted(lambda x: x[0], input_data.items())] - rt_mod["default"](dev)(inputs) - return [x.cpu().numpy() for x in inputs] + inputs = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])] + rt_mod(*inputs) + return inputs else: - raise ValueError(f"Backend {backend} not supported in f_timer!") + raise ValueError(f"Backend {backend} not supported in f_computer!") except Exception as exc: # pylint: disable=broad-except print( diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 6078a87e0482..342f3235b628 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -32,6 +32,8 @@ from tvm._ffi import get_global_func, register_func from tvm.support import describe +DELIMITOR = "\n" + "-" * 30 + "\n" + def _parse_args(): args = argparse.ArgumentParser() @@ -119,18 +121,22 @@ def _parse_args(): @register_func("tvm.meta_schedule.testing.default_input_generator") -def default_input_generator(mod: IRModule) -> List[np.ndarray]: +def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) inputs = [ - generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype) + tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype)) for arg_info in args_info ] return inputs @register_func("tvm.meta_schedule.testing.default_check_metric") -def default_check_metric(a: np.ndarray, b: np.ndarray) -> bool: - return np.allclose(a, b, rtol=1e-3, atol=2e-3) +def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool: + assert len(a) == len(b), "Different number of outputs from two modules" + for i, _ in enumerate(a): + if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3): + return False + return True def validate_correctness( @@ -141,8 +147,12 @@ def validate_correctness( target: Union[str, Target], dev_type: str, rpc_config: ms.runner.RPCConfig, - f_input_generator: Union[str, Callable] = "tvm.meta_schedule.testing.default_input_generator", - f_check_metric: Union[str, Callable] = "tvm.meta_schedule.testing.default_check_metric", + f_input_generator: Union[ + str, Callable[[IRModule], List[tvm.nd.NDArray]] + ] = default_input_generator, + f_check_metric: Union[ + str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool] + ] = default_check_metric, ) -> bool: """Function to validate the correctness of a scheduled module. @@ -167,6 +177,16 @@ def validate_correctness( The result of the validation. """ + def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: + """Convert a list of TVM NDArray to a list of numpy array""" + assert a is not None, "Empty result cannot be converted to numpy" + return [x.numpy() for x in a] + + def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: + """Convert a list of numpy array to a list of TVM NDArray""" + assert a is not None, "Empty result cannot be converted to TVM NDArray" + return [tvm.nd.array(x) for x in a] + def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: """Build and run the module on the target device.""" rt_mod = tvm.build(mod, target=target) @@ -185,26 +205,26 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: baseline_target = Target(baseline_target) # fetch functions & prepare inputs if isinstance(f_input_generator, str): - input_generator_func = get_global_func(f_input_generator) + f_input_generator = get_global_func(f_input_generator) if isinstance(f_check_metric, str): - check_metric_func = get_global_func(f_check_metric) - inputs = input_generator_func(original_mod) + f_check_metric = get_global_func(f_check_metric) + inputs = to_numpy(f_input_generator(original_mod)) # build & run original result - original_res = build_and_run(original_mod, target=baseline_target, dev_type="cpu") - scheduled_res = build_and_run(scheduled_mod, target=target, dev_type=dev_type) + original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu")) + scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type)) # check metric - if not check_metric_func(original_res, scheduled_res): + if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): return True else: print( ("\n\n").join( [ "Validation failed!", - "Original Result:\n" + "-" * 10 + str(original_res), - "Scheduled Result:\n" + "-" * 10 + str(scheduled_res), - "Input:\n" + "-" * 10 + str(inputs), - "Original IRModule:\n" + "-" * 10 + original_mod.script(), - "Scheduled IRModule:\n" + "-" * 10 + scheduled_mod.script(), + "Original Result:" + DELIMITOR + str(original_res), + "Scheduled Result:" + DELIMITOR + str(scheduled_res), + "Input:" + DELIMITOR + str(inputs), + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), ] ) ) @@ -223,7 +243,9 @@ def main(): for i, record in enumerate(records): original_mod = record.workload.mod sch = Schedule(original_mod) - scheduled_mod = record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + flag = False try: flag = validate_correctness( original_mod=original_mod, @@ -238,9 +260,9 @@ def main(): ("\n\n").join( [ "Validation failed!", - "Original IRModule:\n" + "-" * 10 + original_mod.script(), - "Scheduled IRModule:\n" + "-" * 10 + scheduled_mod.script(), - "Exception\n" + "-" * 10 + str(e), + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + "Exception" + DELIMITOR + str(e), ] ) ) From c81ef1ce6e802702aca4584cbd8293b2d63b9fc2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Sep 2022 15:45:04 -0700 Subject: [PATCH 03/12] Fix lint. --- .../tvm/meta_schedule/testing/custom_builder_runner.py | 2 +- python/tvm/meta_schedule/testing/tune_utils.py | 10 +++++----- python/tvm/meta_schedule/testing/validate_database.py | 9 ++++----- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 1cfd4ab833be..2cab51ce6a8d 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -143,7 +143,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: Union["Module", "Executable"], dev_type: str, - args: Dict[str, "np.ndarray"], + args: Dict[Union[str, int], "np.ndarray"], continuation: Callable, backend: Optional[str] = "graph", ): diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index a42d1021e454..988b577b6c9a 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -210,9 +210,9 @@ def create_computer(backend: str) -> Callable: def f_computer( rt_mod: tvm.runtime.Module, - dev: tvm.runtime.Device, + dev: tvm.runtime.Device, # pylint: disable=unused-argument input_data: Dict[str, NDArray], - ) -> None: + ) -> List[NDArray]: """Fetch the result of running the given runtime module. Parameters @@ -226,9 +226,9 @@ def f_computer( """ try: if backend == "tir": - inputs = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])] - rt_mod(*inputs) - return inputs + data = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])] + rt_mod(*data) + return data else: raise ValueError(f"Backend {backend} not supported in f_computer!") diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 342f3235b628..336a407d30c5 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -20,7 +20,7 @@ import argparse import logging import warnings -import numpy as np +import numpy as np # type: ignore import tvm from tvm.target import Target @@ -190,12 +190,11 @@ def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: """Build and run the module on the target device.""" rt_mod = tvm.build(mod, target=target) - args = {i: arg for i, arg in enumerate(inputs)} return run_module_via_rpc( rpc_config=rpc_config, lib=rt_mod, dev_type=dev_type, - args=args, + args=inputs, continuation=create_computer(backend="tir"), backend="tir", ) @@ -208,12 +207,12 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: f_input_generator = get_global_func(f_input_generator) if isinstance(f_check_metric, str): f_check_metric = get_global_func(f_check_metric) - inputs = to_numpy(f_input_generator(original_mod)) + inputs = to_numpy(f_input_generator(original_mod)) # type: ignore # build & run original result original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu")) scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type)) # check metric - if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): + if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore return True else: print( From dc784972eb627a5a8a9901d8ed227a242900d7a6 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 3 Oct 2022 13:45:29 -0700 Subject: [PATCH 04/12] Fix lint. --- python/tvm/meta_schedule/testing/custom_builder_runner.py | 4 ++-- python/tvm/meta_schedule/testing/tune_utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 2cab51ce6a8d..d3d7596d9f51 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -17,7 +17,7 @@ """Customized builder and runner methods""" # pylint: disable=import-outside-toplevel -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Any, Optional, Union, Callable if TYPE_CHECKING: import numpy as np # type: ignore @@ -143,7 +143,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: Union["Module", "Executable"], dev_type: str, - args: Dict[Union[str, int], "np.ndarray"], + args: Dict[Any, "np.ndarray"], continuation: Callable, backend: Optional[str] = "graph", ): diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index 988b577b6c9a..f0844a5ef11e 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -236,5 +236,6 @@ def f_computer( print( f"Run module f_computer via RPC failed, exception: {exc}", ) + return None return f_computer From 7463a055e9a71908970f6ac2276b49f57affe905 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 3 Oct 2022 16:00:44 -0700 Subject: [PATCH 05/12] Fix inputs. --- python/tvm/meta_schedule/testing/validate_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 336a407d30c5..bf869661d367 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -194,7 +194,7 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: rpc_config=rpc_config, lib=rt_mod, dev_type=dev_type, - args=inputs, + args={i: v for i, v in enumerate(inputs)}, continuation=create_computer(backend="tir"), backend="tir", ) From aa66c3013849a0ccf9e0108edd7a18d01060dcfe Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 3 Oct 2022 16:46:04 -0700 Subject: [PATCH 06/12] Fix lint. --- python/tvm/meta_schedule/testing/validate_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index bf869661d367..c72b015d8286 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -194,7 +194,7 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: rpc_config=rpc_config, lib=rt_mod, dev_type=dev_type, - args={i: v for i, v in enumerate(inputs)}, + args={i: v for i, v in enumerate(inputs)}, # type: ignore continuation=create_computer(backend="tir"), backend="tir", ) From fd3a4a0fa044a3a9575c2a21041ea34d83cffb31 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 3 Oct 2022 17:15:20 -0700 Subject: [PATCH 07/12] Fix lint. --- python/tvm/meta_schedule/testing/validate_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index c72b015d8286..0001ddd1989b 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -194,7 +194,7 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: rpc_config=rpc_config, lib=rt_mod, dev_type=dev_type, - args={i: v for i, v in enumerate(inputs)}, # type: ignore + args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension continuation=create_computer(backend="tir"), backend="tir", ) From 36809d628e76c4529286dab06014b750e41d82f1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 11 Oct 2022 16:21:41 -0700 Subject: [PATCH 08/12] Add timer func. --- python/tvm/meta_schedule/profiler.py | 2 +- .../testing/validate_database.py | 63 ++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 7446578a38d7..1776666f4ed5 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -34,7 +34,7 @@ def __init__(self) -> None: ) def get(self) -> Dict[str, float]: - """Get the profiling results in minutes""" + """Get the profiling results in seconds""" return _ffi_api.ProfilerGet(self) # type: ignore # pylint: disable=no-member def table(self) -> str: diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 0001ddd1989b..0568fe387919 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -239,36 +239,41 @@ def main(): assert Target(ARGS.target).kind.name in ["llvm", "cuda"] dev_type = "cpu" if Target(ARGS.target).kind.name == "llvm" else "cuda" records = database.get_all_tuning_records() - for i, record in enumerate(records): - original_mod = record.workload.mod - sch = Schedule(original_mod) - record.trace.apply_to_schedule(sch=sch, remove_postproc=False) - scheduled_mod = sch.mod - flag = False - try: - flag = validate_correctness( - original_mod=original_mod, - scheduled_mod=scheduled_mod, - target=ARGS.target, - baseline_target=ARGS.baseline_target, - dev_type=dev_type, - rpc_config=ARGS.rpc_config, - ) - except Exception as e: # pylint: disable=broad-except, invalid-name - print( - ("\n\n").join( - [ - "Validation failed!", - "Original IRModule:" + DELIMITOR + original_mod.script(), - "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), - "Exception" + DELIMITOR + str(e), - ] + with ms.Profiler() as profiler: + for i, record in enumerate(records): + scope_name = f"validate #{i}" + with profiler.timeit(scope_name): + original_mod = record.workload.mod + sch = Schedule(original_mod) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + flag = False + try: + flag = validate_correctness( + original_mod=original_mod, + scheduled_mod=scheduled_mod, + target=ARGS.target, + baseline_target=ARGS.baseline_target, + dev_type=dev_type, + rpc_config=ARGS.rpc_config, + ) + except Exception as e: # pylint: disable=broad-except, invalid-name + print( + ("\n\n").join( + [ + "Validation failed!", + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + "Exception" + DELIMITOR + str(e), + ] + ) + ) + if flag: + print( + f"Progress {i+1: 6d} / {len(records): 6d} checked, used {float(profiler.get()[scope_name]): 3.3f} sec." ) - ) - if flag: - print(f"Progress {i+1: 6d} / {len(records): 6d} checked.") - else: - return + else: + return print("Validation passed!") From bf6ea0c69f28b139628aff51744fb2bb90020043 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 11 Oct 2022 16:50:24 -0700 Subject: [PATCH 09/12] Fix ci. --- python/tvm/meta_schedule/testing/validate_database.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 0568fe387919..703cc87b1db2 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -270,7 +270,8 @@ def main(): ) if flag: print( - f"Progress {i+1: 6d} / {len(records): 6d} checked, used {float(profiler.get()[scope_name]): 3.3f} sec." + f"Progress {i+1: 6d} / {len(records): 6d} checked," + f" used {float(profiler.get()[scope_name]): 3.3f} sec." ) else: return From b0b28cb70f9947ef8d0672c1d414034a42d717d2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 9 Nov 2022 11:26:54 -0800 Subject: [PATCH 10/12] Address comments. --- .../testing/custom_builder_runner.py | 4 +- .../tvm/meta_schedule/testing/tune_utils.py | 12 ++--- .../testing/validate_database.py | 54 +++++++++---------- 3 files changed, 34 insertions(+), 36 deletions(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index d3d7596d9f51..63c3c57e5cd1 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -17,7 +17,7 @@ """Customized builder and runner methods""" # pylint: disable=import-outside-toplevel -from typing import TYPE_CHECKING, Dict, List, Any, Optional, Union, Callable +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable if TYPE_CHECKING: import numpy as np # type: ignore @@ -143,7 +143,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: Union["Module", "Executable"], dev_type: str, - args: Dict[Any, "np.ndarray"], + args: Dict[Union[int, str], "np.ndarray"], continuation: Callable, backend: Optional[str] = "graph", ): diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index f0844a5ef11e..17064c64ab52 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -194,13 +194,13 @@ def f_time_per_layer( return f_time_per_layer -def create_computer(backend: str) -> Callable: +def create_calculator(backend: str) -> Callable: """Create a function to fetch the computing result of running the given runtime module. Parameters ---------- backend : str - The backend to use, graph / vm. + The backend to use, only tir is supported for now. Returns ------- @@ -208,7 +208,7 @@ def create_computer(backend: str) -> Callable: The function to fetch the computing result. """ - def f_computer( + def f_calculator( rt_mod: tvm.runtime.Module, dev: tvm.runtime.Device, # pylint: disable=unused-argument input_data: Dict[str, NDArray], @@ -230,12 +230,12 @@ def f_computer( rt_mod(*data) return data else: - raise ValueError(f"Backend {backend} not supported in f_computer!") + raise ValueError(f"Backend {backend} not supported in f_calculator!") except Exception as exc: # pylint: disable=broad-except print( - f"Run module f_computer via RPC failed, exception: {exc}", + f"Run module f_calculator via RPC failed, exception: {exc}", ) return None - return f_computer + return f_calculator diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 703cc87b1db2..b4beda461eec 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -28,7 +28,7 @@ from tvm.tir import Schedule from tvm import meta_schedule as ms from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc -from tvm.meta_schedule.testing.tune_utils import create_computer, generate_input_data +from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data from tvm._ffi import get_global_func, register_func from tvm.support import describe @@ -38,25 +38,19 @@ def _parse_args(): args = argparse.ArgumentParser() args.add_argument( - "--path-workload", + "--work-dir", type=str, required=True, - help="The path to the database workload file.", - ) - args.add_argument( - "--path-tuning-record", - type=str, - required=True, - help="The path to the database tuning record file.", + help="The path to the work directory containing database files.", ) args.add_argument( "--target", - type=str, + type=Target, required=True, ) args.add_argument( "--baseline-target", - type=str, + type=Target, default="llvm -num-cores=1", required=False, help="The baseline target to compile the original module.", @@ -114,7 +108,7 @@ def _parse_args(): logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) -logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) # arg parser ARGS = _parse_args() @@ -143,8 +137,8 @@ def validate_correctness( original_mod: IRModule, # compiled for "baseline_target" scheduled_mod: IRModule, # compiled for "target" *, - baseline_target: Union[str, Target], - target: Union[str, Target], + baseline_target: Target, + target: Target, dev_type: str, rpc_config: ms.runner.RPCConfig, f_input_generator: Union[ @@ -162,8 +156,12 @@ def validate_correctness( The original module to be compiled. scheduled_mod : IRModule The scheduled module to be compiled. + baseline_target : Target + The baseline target to compile the original module. target : Target The target to compile the scheduled module. + dev_type : str + The device type to run the module via rpc. rpc_config : RPCConfig The RPCConfig to run the scheduled module. f_input_generator : Union[str, Callable] @@ -173,7 +171,7 @@ def validate_correctness( Returns ------- - result : ... + result : bool The result of the validation. """ @@ -195,13 +193,10 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: lib=rt_mod, dev_type=dev_type, args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension - continuation=create_computer(backend="tir"), + continuation=create_calculator(backend="tir"), backend="tir", ) - # make targets - target = Target(target) - baseline_target = Target(baseline_target) # fetch functions & prepare inputs if isinstance(f_input_generator, str): f_input_generator = get_global_func(f_input_generator) @@ -233,11 +228,14 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: def main(): """Main function""" describe() - database = ms.database.JSONDatabase( - path_workload=ARGS.path_workload, path_tuning_record=ARGS.path_tuning_record - ) - assert Target(ARGS.target).kind.name in ["llvm", "cuda"] - dev_type = "cpu" if Target(ARGS.target).kind.name == "llvm" else "cuda" + database = ms.database.create(work_dir=ARGS.work_dir) + target = ARGS.target + if target.kind.name == "llvm": + dev_type = "cpu" + elif target.kind.name == "cuda": + dev_type = "cuda" + else: + raise RuntimeError(f"Unsupported target kind: {target.kind.name}") records = database.get_all_tuning_records() with ms.Profiler() as profiler: for i, record in enumerate(records): @@ -247,12 +245,12 @@ def main(): sch = Schedule(original_mod) record.trace.apply_to_schedule(sch=sch, remove_postproc=False) scheduled_mod = sch.mod - flag = False + is_success = False try: - flag = validate_correctness( + is_success = validate_correctness( original_mod=original_mod, scheduled_mod=scheduled_mod, - target=ARGS.target, + target=target, baseline_target=ARGS.baseline_target, dev_type=dev_type, rpc_config=ARGS.rpc_config, @@ -268,7 +266,7 @@ def main(): ] ) ) - if flag: + if is_success: print( f"Progress {i+1: 6d} / {len(records): 6d} checked," f" used {float(profiler.get()[scope_name]): 3.3f} sec." From 5541d7c93ef9ed1caf8348be9b1415df865aa811 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 9 Nov 2022 11:30:31 -0800 Subject: [PATCH 11/12] Add total time statistics. --- python/tvm/meta_schedule/testing/validate_database.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index b4beda461eec..5e48bfb6b04e 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -275,6 +275,7 @@ def main(): return print("Validation passed!") + print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") if __name__ == "__main__": From 6dff2e1d01e1b0556729635846de592df05de964 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 9 Nov 2022 11:59:55 -0800 Subject: [PATCH 12/12] Fix lint. --- python/tvm/meta_schedule/testing/custom_builder_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 63c3c57e5cd1..7129546dd8b7 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -143,7 +143,7 @@ def run_module_via_rpc( rpc_config: "RPCConfig", lib: Union["Module", "Executable"], dev_type: str, - args: Dict[Union[int, str], "np.ndarray"], + args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]], continuation: Callable, backend: Optional[str] = "graph", ):