2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/profiler.py
@@ -34,7 +34,7 @@ def __init__(self) -> None:
)

def get(self) -> Dict[str, float]:
-     """Get the profiling results in minutes"""
+     """Get the profiling results in seconds"""
return _ffi_api.ProfilerGet(self) # type: ignore # pylint: disable=no-member

def table(self) -> str:
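As a point of reference, here is a minimal sketch of how the profiler is driven (mirroring its use in validate_database.py, added below); the scope name is illustrative:

from tvm import meta_schedule as ms

with ms.Profiler() as profiler:
    with profiler.timeit("validate #0"):   # time a named scope
        ...                                # work being measured
    print(profiler.get()["validate #0"])   # per-scope wall-clock time, in seconds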
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,7 +17,7 @@
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel

- from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable

if TYPE_CHECKING:
import numpy as np # type: ignore
@@ -143,7 +143,7 @@ def run_module_via_rpc(
rpc_config: "RPCConfig",
lib: Union["Module", "Executable"],
dev_type: str,
- args: Dict[str, "np.ndarray"],
+ args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
continuation: Callable,
backend: Optional[str] = "graph",
):
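The widened args type matches how the new validation script calls this helper, keying the inputs by positional index. A minimal sketch, assuming rt_mod (a built Module), inputs (a list of np.ndarray), and rpc_config already exist:

from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.tune_utils import create_calculator

outputs = run_module_via_rpc(
    rpc_config=rpc_config,
    lib=rt_mod,
    dev_type="cpu",
    args={i: v for i, v in enumerate(inputs)},  # Dict[int, np.ndarray] is now accepted
    continuation=create_calculator(backend="tir"),
    backend="tir",
)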
55 changes: 51 additions & 4 deletions python/tvm/meta_schedule/testing/tune_utils.py
@@ -86,7 +86,7 @@ def create_timer(backend: str) -> Callable:

def f_timer(
rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
- dev: tvm.device,
+ dev: tvm.runtime.Device,
input_data: Dict[str, NDArray],
) -> None:
"""Run and benchmark the given runtime module, print out the result.
@@ -95,7 +95,7 @@ def f_timer(
----------
rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
The runtime module or vm executable.
- dev : tvm.device
+ dev : tvm.runtime.Device
The device type to run workload.
input_data : Dict[str, np.ndarray]
The input data as a dictionary.
@@ -152,7 +152,7 @@ def create_time_per_layer(graph: str) -> Callable:

def f_time_per_layer(
rt_mod: tvm.runtime.Module,
- dev: tvm.device,
+ dev: tvm.runtime.Device,
input_data: Dict[str, NDArray],
) -> None:
"""Run and benchmark the per-layer performance of given runtime module,
@@ -162,7 +162,7 @@ def f_time_per_layer(
----------
rt_mod : tvm.runtime.Module
The runtime module.
- dev : tvm.device
+ dev : tvm.runtime.Device
The device type to run workload.
input_data : Dict[str, np.ndarray]
The input data as a dictionary.
@@ -192,3 +192,50 @@ def f_time_per_layer(
)

return f_time_per_layer


def create_calculator(backend: str) -> Callable:
    """Create a function that fetches the computed result of running a given runtime module.

Parameters
----------
backend : str
The backend to use; only "tir" is supported for now.

Returns
-------
func : Callable
The function that fetches the computed result.
"""

def f_calculator(
rt_mod: tvm.runtime.Module,
dev: tvm.runtime.Device, # pylint: disable=unused-argument
input_data: Dict[str, NDArray],
) -> List[NDArray]:
"""Fetch the result of running the given runtime module.

Parameters
----------
rt_mod : tvm.runtime.Module
    The runtime module to execute.
dev : tvm.runtime.Device
    The device to run the workload on (currently unused).
input_data : Dict[str, np.ndarray]
    The input data as a dictionary; entries are passed to the module positionally, sorted by key.

Returns
-------
data : List[NDArray]
    The argument buffers after execution, with the outputs written in place.
"""
try:
if backend == "tir":
data = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])]
rt_mod(*data)
return data
else:
raise ValueError(f"Backend {backend} not supported in f_calculator!")

except Exception as exc: # pylint: disable=broad-except
print(
    f"Running module f_calculator via RPC failed, exception: {exc}",
)
return None

return f_calculator
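A brief usage sketch for the returned closure (the module and buffer names are assumptions): the keys of input_data determine the positional argument order, and the same buffers come back with the outputs written into them.

calc = create_calculator(backend="tir")
result = calc(
    rt_mod,                       # a built tvm.runtime.Module for a TIR PrimFunc (assumed to exist)
    tvm.cpu(),                    # the device argument is currently unused by f_calculator
    {0: a_nd, 1: b_nd, 2: c_nd},  # hypothetical NDArrays; sorted by key -> rt_mod(a_nd, b_nd, c_nd)
)
# `result` is the same list of buffers, with the module's outputs written in place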
282 changes: 282 additions & 0 deletions python/tvm/meta_schedule/testing/validate_database.py
@@ -0,0 +1,282 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""JSON Database validation script"""
from typing import Union, Callable, List
from distutils.util import strtobool
import argparse
import logging
import warnings
import numpy as np # type: ignore

import tvm
from tvm.target import Target
from tvm.ir import IRModule
from tvm.tir import Schedule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data
from tvm._ffi import get_global_func, register_func
from tvm.support import describe

DELIMITOR = "\n" + "-" * 30 + "\n"


def _parse_args():
args = argparse.ArgumentParser()
args.add_argument(
"--work-dir",
type=str,
required=True,
help="The path to the work directory containing database files.",
)
args.add_argument(
"--target",
type=Target,
required=True,
)
args.add_argument(
"--baseline-target",
type=Target,
default="llvm -num-cores=1",
required=False,
help="The baseline target to compile the original module.",
)
args.add_argument(
"--rpc-host",
type=str,
required=True,
)
args.add_argument(
"--rpc-port",
type=int,
required=True,
)
args.add_argument(
"--rpc-key",
type=str,
required=True,
)
args.add_argument(
"--number",
type=int,
default=3,
)
args.add_argument(
"--repeat",
type=int,
default=1,
)
args.add_argument(
"--min-repeat-ms",
type=int,
default=100,
)
args.add_argument(
"--cpu-flush",
type=lambda x: bool(strtobool(x)),
help="example: True / False",
required=True,
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
parsed.rpc_config = ms.runner.RPCConfig(
tracker_host=parsed.rpc_host,
tracker_port=parsed.rpc_port,
tracker_key=parsed.rpc_key,
session_timeout_sec=600,
)
if parsed.cpu_flush and parsed.target.kind.name != "llvm":
warnings.warn("cpu_flush is only supported on llvm target")
return parsed


# logging
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)

# arg parser
ARGS = _parse_args()


@register_func("tvm.meta_schedule.testing.default_input_generator")
def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"])
inputs = [
tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype))
for arg_info in args_info
]
return inputs


@register_func("tvm.meta_schedule.testing.default_check_metric")
def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
assert len(a) == len(b), "Different number of outputs from two modules"
for i, _ in enumerate(a):
if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
return False
return True


def validate_correctness(
original_mod: IRModule, # compiled for "baseline_target"
scheduled_mod: IRModule, # compiled for "target"
*,
baseline_target: Target,
target: Target,
dev_type: str,
rpc_config: ms.runner.RPCConfig,
f_input_generator: Union[
str, Callable[[IRModule], List[tvm.nd.NDArray]]
] = default_input_generator,
f_check_metric: Union[
str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
] = default_check_metric,
) -> bool:
"""Function to validate the correctness of a scheduled module.

Parameters
----------
original_mod : IRModule
The original module to be compiled.
scheduled_mod : IRModule
The scheduled module to be compiled.
baseline_target : Target
The baseline target to compile the original module.
target : Target
The target to compile the scheduled module.
dev_type : str
The device type to run the module via rpc.
rpc_config : RPCConfig
The RPCConfig to run the scheduled module.
f_input_generator : Union[str, Callable]
The function to generate the input data.
f_check_metric : Union[str, Callable]
The function to check the metric.

Returns
-------
result : bool
The result of the validation.
"""

def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
"""Convert a list of TVM NDArray to a list of numpy array"""
assert a is not None, "Empty result cannot be converted to numpy"
return [x.numpy() for x in a]

def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
"""Convert a list of numpy array to a list of TVM NDArray"""
assert a is not None, "Empty result cannot be converted to TVM NDArray"
return [tvm.nd.array(x) for x in a]

def build_and_run(mod: IRModule, target: Target, dev_type: str) -> List[tvm.nd.NDArray]:
"""Build and run the module on the target device."""
rt_mod = tvm.build(mod, target=target)
return run_module_via_rpc(
rpc_config=rpc_config,
lib=rt_mod,
dev_type=dev_type,
args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
continuation=create_calculator(backend="tir"),
backend="tir",
)

# fetch functions & prepare inputs
if isinstance(f_input_generator, str):
f_input_generator = get_global_func(f_input_generator)
if isinstance(f_check_metric, str):
f_check_metric = get_global_func(f_check_metric)
inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
# build & run original result
original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
# check metric
if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
return True
else:
print(
("\n\n").join(
[
"Validation failed!",
"Original Result:" + DELIMITOR + str(original_res),
"Scheduled Result:" + DELIMITOR + str(scheduled_res),
"Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
]
)
)
return False


def main():
"""Main function"""
describe()
database = ms.database.create(work_dir=ARGS.work_dir)
target = ARGS.target
if target.kind.name == "llvm":
dev_type = "cpu"
elif target.kind.name == "cuda":
dev_type = "cuda"
else:
raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
for i, record in enumerate(records):
scope_name = f"validate #{i}"
with profiler.timeit(scope_name):
original_mod = record.workload.mod
sch = Schedule(original_mod)
record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
scheduled_mod = sch.mod
is_success = False
try:
is_success = validate_correctness(
original_mod=original_mod,
scheduled_mod=scheduled_mod,
target=target,
baseline_target=ARGS.baseline_target,
dev_type=dev_type,
rpc_config=ARGS.rpc_config,
)
except Exception as e: # pylint: disable=broad-except, invalid-name
print(
("\n\n").join(
[
"Validation failed!",
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
"Exception" + DELIMITOR + str(e),
]
)
)
if is_success:
print(
f"Progress {i+1: 6d} / {len(records): 6d} checked,"
f" used {float(profiler.get()[scope_name]): 3.3f} sec."
)
else:
return

print("Validation passed!")
print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.")


if __name__ == "__main__":
main()
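Since f_input_generator and f_check_metric accept either a callable or the name of a registered global function, the defaults can be swapped out through the FFI registry. A minimal sketch, assuming a custom metric registered under a hypothetical name:

import numpy as np
from tvm._ffi import register_func

@register_func("tvm.meta_schedule.testing.my_check_metric")  # hypothetical registry name
def my_check_metric(a, b) -> bool:
    # looser tolerance than the default check
    return all(np.allclose(x.numpy(), y.numpy(), rtol=1e-2, atol=1e-2) for x, y in zip(a, b))

# then: validate_correctness(..., f_check_metric="tvm.meta_schedule.testing.my_check_metric")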