diff --git a/python/tvm/meta_schedule/testing/torchbench/run.py b/python/tvm/meta_schedule/testing/torchbench/run.py
index 20c633196900..5df77cf25c3f 100644
--- a/python/tvm/meta_schedule/testing/torchbench/run.py
+++ b/python/tvm/meta_schedule/testing/torchbench/run.py
@@ -89,19 +89,25 @@
 ```
 """
 # pylint: disable=logging-format-interpolation
+
 import argparse
-import functools
+import contextlib
 import logging
+import os
+import sys
 import warnings
+from collections import defaultdict
 from enum import Enum
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Dict
 
 import numpy as np  # type: ignore
 import torch  # type: ignore
+from scipy.stats import ttest_ind  # type: ignore
+
 import tvm
 import tvm.relay
-from scipy.stats import ttest_ind  # type: ignore
 from tvm import meta_schedule as ms
+from tvm._ffi import get_global_func
 from tvm.contrib.graph_executor import GraphModule
 from tvm.meta_schedule.testing.torchbench.utils import (
     load_torchdynamo_benchmark_runner,
@@ -201,6 +207,13 @@ def parse_args():
         https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models.
         """,
     )
+    args.add_argument(
+        "--float32",
+        action="store_true",
+        help="""
+        Cast model and inputs to fp32
+        """,
+    )
 
     # Tuning-related config
     args.add_argument(
@@ -217,6 +230,12 @@ def parse_args():
         The working directory to save intermediate results and store databases for compilation.
         """,
     )
+    args.add_argument(
+        "--strategy",
+        type=str,
+        default="evolutionary",
+        help="The search strategy used by MetaSchedule.",
+    )
     args.add_argument(
         "--num-trials",
         type=int,
@@ -293,6 +312,10 @@ def parse_args():
     )
 
     parsed = args.parse_args()
+
+    # Trim all args; otherwise they confuse the arg parser of timm_efficientdet.
+    sys.argv = sys.argv[:1]
+
     return parsed
 
 
@@ -311,6 +334,7 @@ def parse_args():
 runner = load_torchdynamo_benchmark_runner(  # pylint: disable=invalid-name
     IS_CUDA,
     cosine_similarity=ARGS.result_metric == ResultComparisonMetric.COSINE,
+    float32=ARGS.float32,
 )
 
 
@@ -343,26 +367,49 @@ def get_meta_schedule_runner() -> ms.runner.PyRunner:
     return ms.runner.LocalRunner()
 
 
-def get_graph_executor_forward(mod: GraphModule, device: tvm.runtime.Device) -> Callable:
+def get_graph_executor_forward(
+    graph_executor_factory: tvm.runtime.Module, device: tvm.runtime.Device
+) -> Callable:
     """
     Get the forward function for graph executor, in order to integrate with TorchDynamo.
     """
-    def forward(*args):
-        if IS_CUDA:
-            torch.cuda.synchronize()
-        args = tuple(arg.contiguous() for arg in args)
-        for idx, arg in enumerate(args, 0):
-            mod.set_input(
-                f"inp_{idx}",
-                tvm.nd.from_dlpack(arg),
-            )
-        mod.run()
-        device.sync()
-        result = [torch.from_dlpack(mod.get_output(i)) for i in range(mod.get_num_outputs())]
-        return result
+    # This package has to be imported lazily, so that the C++ PyTorch integration
+    # is loaded after the transformers package is imported when loading the model.
+    # Otherwise there will be a segfault caused by the protobuf library.
+    import tvm.contrib.torch  # pylint: disable=import-outside-toplevel, unused-import, redefined-outer-name
 
-    return forward
+    save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
+    if save_runtime_mod is None:
+        warnings.warn(
+            "C++ PyTorch TVM integration is missing. Falling back to the Python forward function. "
+            "Build TVM with 'USE_PT_TVMDSOOP' to enable the C++ custom operator."
+        )
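+        # Python fallback: feed inputs to the graph executor and read outputs back
+        # through DLPack, so tensors move between PyTorch and TVM without copies.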
+ "Build TVM with 'USE_PT_TVMDSOOP' to enable the C++ custom operator" + ) + mod = GraphModule(graph_executor_factory["default"](device)) + + def forward(*args): + if IS_CUDA: + torch.cuda.synchronize() + args = tuple(arg.detach().contiguous() for arg in args) + for idx, arg in enumerate(args, 0): + mod.set_input( + f"inp_{idx}", + tvm.nd.from_dlpack(arg), + ) + mod.run() + device.sync() + result = [torch.from_dlpack(mod.get_output(i)) for i in range(mod.get_num_outputs())] + return result + + return forward + else: + save_runtime_mod(graph_executor_factory.module) + module = torch.classes.tvm_torch.GraphExecutorFactoryWrapper() + + def forward(*args): # type: ignore # isort: skip, pylint: disable=function-redefined + return module.forward(args) + + return forward def get_vm_forward(virtual_machine: VirtualMachine, device: tvm.runtime.Device) -> Callable: @@ -373,7 +420,7 @@ def get_vm_forward(virtual_machine: VirtualMachine, device: tvm.runtime.Device) def forward(*args): if IS_CUDA: torch.cuda.synchronize() - args = tuple(tvm.nd.from_dlpack(arg.contiguous()) for arg in args) + args = tuple(tvm.nd.from_dlpack(arg.detach().contiguous()) for arg in args) result = virtual_machine.invoke("main", *args) device.sync() @@ -384,13 +431,36 @@ def forward(*args): return forward -def create_tvm_task_collection_backend(tasks: List[ms.ExtractedTask]) -> Callable: +def create_tvm_task_collection_backend() -> Tuple[Callable, List[ms.ExtractedTask]]: """ This torchdynamo backend only collects the extracted tasks from MetaSchedule. It doesn't tune the model. """ + subgraph_idx = 0 + subgraphs_dir = os.path.join(ARGS.work_dir, "subgraphs") + os.makedirs(subgraphs_dir, exist_ok=True) + + collected_tasks = [] + task_index: Dict[int, List[ms.ExtractedTask]] = defaultdict(list) + + def collect_task(task): + task_hash = tvm.ir.structural_hash(task.dispatched[0]) + + for duplicate_task in task_index[task_hash]: + if tvm.ir.structural_equal(duplicate_task.dispatched[0], task.dispatched[0]): + duplicate_task.weight += task.weight + return + + task_index[task_hash].append(task) + collected_tasks.append(task) + def backend(graph_module, example_inputs): + nonlocal subgraph_idx + + torch.save(graph_module, os.path.join(subgraphs_dir, f"graph_module_{subgraph_idx}")) + torch.save(example_inputs, os.path.join(subgraphs_dir, f"example_inputs_{subgraph_idx}")) + jit_mod = torch.jit.trace(graph_module, example_inputs) shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list) @@ -400,12 +470,21 @@ def backend(graph_module, example_inputs): target=ARGS.target, params=params, ) - logger.info("Extracted %d tasks", len(extracted_tasks)) - tasks.extend(extracted_tasks) + old_tasks_count = len(collected_tasks) + for task in extracted_tasks: + collect_task(task) + logger.info( + "Extracted %d tasks from graph %d, with %d new tasks", + len(extracted_tasks), + subgraph_idx, + len(collected_tasks) - old_tasks_count, + ) + + subgraph_idx += 1 return graph_module.forward - return backend + return backend, collected_tasks def create_tvm_compilation_backend(database: ms.database.Database) -> Callable: @@ -429,8 +508,7 @@ def backend(graph_module, example_inputs): device = tvm.cuda(0) if IS_CUDA else tvm.cpu(0) if ARGS.backend == "graph": - mod = GraphModule(lib["default"](device)) - return get_graph_executor_forward(mod, device) + return get_graph_executor_forward(lib, device) elif ARGS.backend == "vm": vm = VirtualMachine(lib, device) # pylint: 
+    def collect_task(task):
+        task_hash = tvm.ir.structural_hash(task.dispatched[0])
+
+        for duplicate_task in task_index[task_hash]:
+            if tvm.ir.structural_equal(duplicate_task.dispatched[0], task.dispatched[0]):
+                duplicate_task.weight += task.weight
+                return
+
+        task_index[task_hash].append(task)
+        collected_tasks.append(task)
+
     def backend(graph_module, example_inputs):
+        nonlocal subgraph_idx
+
+        torch.save(graph_module, os.path.join(subgraphs_dir, f"graph_module_{subgraph_idx}"))
+        torch.save(example_inputs, os.path.join(subgraphs_dir, f"example_inputs_{subgraph_idx}"))
+
         jit_mod = torch.jit.trace(graph_module, example_inputs)
         shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
         ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list)
@@ -400,12 +470,21 @@ def backend(graph_module, example_inputs):
             target=ARGS.target,
             params=params,
         )
-        logger.info("Extracted %d tasks", len(extracted_tasks))
-        tasks.extend(extracted_tasks)
+        old_tasks_count = len(collected_tasks)
+        for task in extracted_tasks:
+            collect_task(task)
+        logger.info(
+            "Extracted %d tasks from graph %d, with %d new tasks",
+            len(extracted_tasks),
+            subgraph_idx,
+            len(collected_tasks) - old_tasks_count,
+        )
+
+        subgraph_idx += 1
 
         return graph_module.forward
 
-    return backend
+    return backend, collected_tasks
 
 
 def create_tvm_compilation_backend(database: ms.database.Database) -> Callable:
@@ -429,8 +508,7 @@ def backend(graph_module, example_inputs):
         device = tvm.cuda(0) if IS_CUDA else tvm.cpu(0)
 
         if ARGS.backend == "graph":
-            mod = GraphModule(lib["default"](device))
-            return get_graph_executor_forward(mod, device)
+            return get_graph_executor_forward(lib, device)
         elif ARGS.backend == "vm":
             vm = VirtualMachine(lib, device)  # pylint: disable=invalid-name
             return get_vm_forward(vm, device)
@@ -463,6 +541,67 @@ def is_output_correct(output: torch.Tensor, expected: torch.Tensor) -> bool:
     raise RuntimeError(f"Unknown comparison metric {comparison_metric}")
 
 
+def inspect_output_error(output, expected):
+    """
+    Inspect the error between the actual output and the expected output.
+    """
+    if not isinstance(output, torch.Tensor):
+        logger.info(
+            f"Unsupported type for error inspection: {type(output).__name__}. "
+            f"Please manually check output.pt"
+        )
+        return
+    output = output.cpu().float()
+    expected = expected.cpu().float()
+
+    abs_error = (output - expected).abs()
+    rel_error = (abs_error / expected).abs()
+
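+    # Summarize the error tensors as histograms over power-of-ten bins, so a
+    # few large outliers stay visible next to the bulk of near-zero entries.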
+    def format_error_table(error, bins) -> str:
+        bin_tensor = torch.as_tensor([float(b) for b in bins], dtype=error.dtype)
+        error_hist = torch.histogram(error, bin_tensor).hist.int()
+        return "\n".join(f"< {b}\t{e}" for e, b in zip(error_hist, bins[1:]))
+
+    abs_error_bins = [
+        "-1e10",
+        "0",
+        "1e-8",
+        "1e-6",
+        "1e-5",
+        "1e-4",
+        "1e-3",
+        "1e-2",
+        "1e-1",
+        "1",
+        "1e10",
+    ]
+    rel_error_bins = [
+        "-1e10",
+        "0",
+        "1e-4",
+        "1e-3",
+        "1e-2",
+        "1e-1",
+        "1",
+        "1e1",
+        "1e2",
+        "1e3",
+        "1e100",
+    ]
+
+    large_rel_error_idx = rel_error > 1
+    abs_error_with_large_rel_error = abs_error[large_rel_error_idx]
+
+    logger.error(f"Expected (PyTorch eager): {expected}")
+    logger.error(f"Actual (Optimized): {output}")
+    logger.error(f"Absolute Error\n{format_error_table(abs_error, abs_error_bins)}")
+    logger.error(f"Relative Error\n{format_error_table(rel_error, rel_error_bins)}")
+    logger.error(
+        f"Max absolute error for positions with large relative error (> 1): "
+        f"{abs_error_with_large_rel_error.max()}"
+    )
+
+
 def performance_experiment(
     model_iter_fn: Callable,
     model: torch.nn.Module,
@@ -473,6 +612,8 @@
     """
     Simplified from https://github.com/pytorch/torchdynamo/blob/c537639f9712621dc04ca09908796dbbe86c354b/benchmarks/common.py#L494 pylint: disable=line-too-long
     """
     timings = np.zeros((ARGS.benchmark_repeat, 2), np.float64)
+    if IS_CUDA:
+        torch.cuda.empty_cache()
 
     is_correct = True
@@ -500,10 +641,11 @@
         f"optimized:{format_time(median[1])} "
         f"speedup:{speedup:.3f}x p:{pvalue:.3f}"
     )
+    torch.save(actual_output, os.path.join(ARGS.work_dir, "output.pt"))
+    torch.save(expected_output, os.path.join(ARGS.work_dir, "expected.pt"))
     if not is_correct:
         logger.error("Result is incorrect.")
-        logger.error(f"Expected (PyTorch eager): {expected_output}")
-        logger.error(f"Actual (Optimized): {actual_output}")
+        inspect_output_error(actual_output, expected_output)
 
     return ""
@@ -523,7 +665,10 @@ def main():
     """
     describe()
-    database = ms.database.JSONDatabase(work_dir=ARGS.work_dir)
+    meta_schedule_work_dir = os.path.join(ARGS.work_dir, "meta_schedule")
+    os.makedirs(meta_schedule_work_dir, exist_ok=True)
+
+    database = ms.database.JSONDatabase(work_dir=meta_schedule_work_dir)
     if not ARGS.mode.should_tune:
         if len(database) == 0:
             raise RuntimeError(
@@ -539,45 +684,54 @@
         ARGS.cpu_flush = False
 
     try:
+        logger.info(f"Loading model with batch size: {ARGS.batch_size}")
         _, name, model, example_inputs, batch_size = runner.load_model(
             get_torch_device_type(ARGS.target),
             ARGS.model,
             batch_size=ARGS.batch_size,
         )
-        logger.info(
-            f"batch size: {batch_size} input shape: {[input.shape for input in example_inputs]}"
-        )
+        model, example_inputs = runner.maybe_cast(model, example_inputs)
+        logger.info(f"Got model with batch size: {batch_size}")
     except NotImplementedError:
-        logging.exception(f"{ARGS.model} failed to load")
-        return
+        logger.exception(f"{ARGS.model} failed to load")
+        raise
+
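+    # A single ExitStack scopes both the MetaSchedule profiler and
+    # torch.no_grad() across the task-extraction, tuning, and eval steps below.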
+    with contextlib.ExitStack() as stack:
+        profiler = stack.enter_context(ms.Profiler())
+        stack.enter_context(torch.no_grad())
+
+        if ARGS.mode.should_tune:
+            task_collect_backend, extracted_tasks = create_tvm_task_collection_backend()
+            task_collect_ctx = torchdynamo.optimize(task_collect_backend)
+            task_collect_ctx(runner.model_iter_fn)(model, example_inputs)
+
+            tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts(
+                extracted_tasks=extracted_tasks,
+                work_dir=ARGS.work_dir,
+                strategy=ARGS.strategy,
+            )
+            database = ms.tune.tune_tasks(
+                tasks=tasks,
+                task_weights=task_weights,
+                work_dir=ARGS.work_dir,
+                max_trials_global=ARGS.num_trials,
+                max_trials_per_task=ARGS.max_trials_per_task,
+                runner=get_meta_schedule_runner(),  # type: ignore
+                database=database,
+                cost_model=ms.cost_model.XGBModel(  # type: ignore
+                    extractor=ms.feature_extractor.PerStoreFeature(),
+                    adaptive_training=ARGS.adaptive_training,
+                ),
+            )
-
-    if ARGS.mode.should_tune:
-        extracted_tasks: List[ms.ExtractedTask] = []
-        task_collect_ctx = torchdynamo.optimize(create_tvm_task_collection_backend(extracted_tasks))
-        task_collect_ctx(runner.model_iter_fn)(model, example_inputs)
-        tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts(
-            extracted_tasks=extracted_tasks,
-            work_dir=ARGS.work_dir,
-        )
-        database = ms.tune.tune_tasks(
-            tasks=tasks,
-            task_weights=task_weights,
-            work_dir=ARGS.work_dir,
-            max_trials_global=ARGS.num_trials,
-            max_trials_per_task=ARGS.num_trials_per_task,
-            runner=get_meta_schedule_runner(),  # type: ignore
-            database=database,
-            cost_model=ms.cost_model.XGBModel(  # type: ignore
-                extractor=ms.feature_extractor.PerStoreFeature(),
-                adaptive_training=ARGS.adaptive_training,
-            ),
-        )
+        if ARGS.mode.should_eval:
+            torchdynamo.reset()
+            model_compile_ctx = torchdynamo.optimize(create_tvm_compilation_backend(database))
+            model_compile_ctx(runner.model_iter_fn)(model, example_inputs)
+            with torch.no_grad():
+                performance_experiment(runner.model_iter_fn, model, example_inputs)
-
-    if ARGS.mode.should_eval:
-        torchdynamo.reset()
-        model_compile_ctx = torchdynamo.optimize(create_tvm_compilation_backend(database))
-        experiment = functools.partial(performance_experiment, runner.model_iter_fn)
-        runner.run_one_model(name, model, example_inputs, model_compile_ctx, experiment)
+
+    print(profiler.table())
 
 
 if __name__ == "__main__":
diff --git a/python/tvm/meta_schedule/testing/torchbench/utils.py b/python/tvm/meta_schedule/testing/torchbench/utils.py
index f5a745ea008a..8bd022a9cb18 100644
--- a/python/tvm/meta_schedule/testing/torchbench/utils.py
+++ b/python/tvm/meta_schedule/testing/torchbench/utils.py
@@ -51,7 +51,9 @@ def find_torchdynamo() -> str:
 
 
 DYNAMO_DIR = find_torchdynamo()
-sys.path.append(DYNAMO_DIR)
+sys.path.insert(
+    0, DYNAMO_DIR
+)  # opacus_cifar10 depends on opacus, which installs a package called 'benchmarks'
 sys.path.append(f"{DYNAMO_DIR}/benchmarks")
 
 # pylint: disable=wrong-import-position, unused-import
@@ -62,7 +64,7 @@ def find_torchdynamo() -> str:
 
 
 def load_torchdynamo_benchmark_runner(
-    is_cuda: bool, cosine_similarity: bool = False
+    is_cuda: bool, cosine_similarity: bool = False, float32: bool = False
 ) -> TorchBenchmarkRunner:
     """
     Load the benchmark runner from TorchDynamo.
@@ -86,7 +88,7 @@ class RunnerArgs:
         cosine: bool = False  # Whether to use consine similarity to check if output is correct.
 
-    args = RunnerArgs(cosine=cosine_similarity)
+    args = RunnerArgs(cosine=cosine_similarity, float32=float32)
     runner = TorchBenchmarkRunner()
     runner.args = args