diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 52c7f44fcede..973cbf19bece 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -84,6 +84,7 @@ def extract_tasks(
     include_simple_tasks=False,
     dump_workload_to_dag_log=None,
     opt_level=3,
+    other_targets=None,
 ):
     """Extract tuning tasks from a relay program.
 
@@ -105,6 +106,8 @@ def extract_tasks(
         A file to dump an association between the workload keys and the actual DAG
     opt_level : Optional[int]
         The optimization level of the task extractions.
+    other_targets: Optional[List[tvm.target.Target]]
+        Other targets for call_all_topi_funcs, e.g., cutlass target.
 
     Returns
     -------
@@ -125,12 +128,15 @@ def extract_tasks(
     old_verbose = dispatch_ctx.verbose
     dispatch_ctx.verbose = 0
 
+    targets = [target]
+    if other_targets is not None:
+        targets += other_targets
     errors = []
     with env:
         # Wrap build call in a new thread to avoid the conflict
         # between python's multiprocessing and tvm's thread pool
         build_thread = threading.Thread(
-            target=call_all_topi_funcs, args=(mod, params, target, errors, opt_level)
+            target=call_all_topi_funcs, args=(mod, params, targets, errors, opt_level)
         )
         build_thread.start()
         build_thread.join()
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 753ee178f9d3..f3d2e98e8937 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -15,13 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+import tempfile
 import math
 import tvm
 from tvm import relay
 from tvm.contrib.cudnn import conv_output_shape
 import numpy as np
+from tvm.relay import op as _op
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay.op.contrib.cutlass import partition_for_cutlass
+from tvm import auto_scheduler
 from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType
 from tvm.contrib.cutlass import (
     has_cutlass,
@@ -235,6 +238,32 @@ def get_conv2d_backward_weight(
     )
 
 
+def get_dense_transpose_dense(M, N, K, dtype="float16"):
+    """
+    output = nn.dense(_op.transpose(nn.dense(input, weight0), axes=(1, 0)), weight1)
+
+    dense0: [M, K] * [N, K] -> [M, N]
+    transpose: [M, N] -> [N, M]
+    dense1: [N, M] * [K, M] -> [N, K]
+
+    input: [M, K]
+    weight0: [N, K]
+    weight1: [K, M]
+    """
+    input_shape = (M, K)
+    weight0_shape = (N, K)
+    weight1_shape = (K, M)
+
+    input = relay.var("input", shape=input_shape, dtype=dtype)
+    weight0 = relay.var("weight0", shape=weight0_shape, dtype=dtype)
+    weight1 = relay.var("weight1", shape=weight1_shape, dtype=dtype)
+
+    output0 = relay.nn.dense(input, weight0, out_dtype=dtype)
+    input1 = _op.transpose(output0, axes=(1, 0))
+    output = relay.nn.dense(input1, weight1, out_dtype=dtype)
+    return output
+
+
 def convert_conv2d_layout(mod, desired_layouts):
     with tvm.transform.PassContext(opt_level=3):
         seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
@@ -257,6 +286,8 @@ def profile_and_build(
     tmp_dir="./tmp",
     use_fast_math=False,
     use_3xtf32=True,
+    use_ansor=False,
+    ansor_tuning=False,
 ):
     logging.info("before partitioning:\n%s", mod)
     mod = partition_for_cutlass(mod)
@@ -279,8 +310,53 @@ def profile_and_build(
         },
         host=host,
     )
-    with tvm.transform.PassContext(opt_level=3):
-        lib = relay.build(mod, target=[cuda, cutlass], params=params)
+
+    if use_ansor:
+        with tvm.transform.PassContext(
+            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
+        ):
+            tasks, task_weights = auto_scheduler.extract_tasks(
+                mod, params, cuda, include_simple_tasks=True, opt_level=3, other_targets=[cutlass]
+            )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            logging.info(
+                f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) ====="
+            )
+            logging.info(task.compute_dag)
+
+        with tempfile.NamedTemporaryFile() as fp:
+            log_file = fp.name
+
+            # auto-tuning is disabled by default
+            if ansor_tuning:
+                measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+                    repeat=3, min_repeat_ms=200, timeout=10
+                )
+                tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+                tuner.tune(
+                    auto_scheduler.TuningOptions(
+                        num_measure_trials=100,
+                        runner=measure_ctx.runner,
+                        measure_callbacks=[
+                            auto_scheduler.RecordToFile(log_file),
+                        ],
+                    )
+                )
+
+            with auto_scheduler.ApplyHistoryBest(log_file):
+                with tvm.transform.PassContext(
+                    opt_level=3,
+                    config={"relay.backend.use_auto_scheduler": True},
+                ):
+                    lib = relay.build(
+                        mod,
+                        target=cuda,
+                        target_host=host,
+                        params=params,
+                    )
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=[cuda, cutlass], params=params)
     lib = finalize_modules(lib, "compile.so", tmp_dir)
     dev = tvm.device("cuda", 0)
     rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
@@ -959,5 +1035,65 @@ def test_conv2d_bwd():
         )
 
 
+def verify_dense_transpose_dense(
+    func,
+    M,
+    N,
+    K,
+    ref_target="cuda",
+    sm=80,
+    atol=1e-5,
+    rtol=1e-5,
+    run_benchmark=False,
+    dtype="float16",
+    use_3xtf32=True,
+):
+    assert has_cutlass()
+    if sm < 80 and dtype == "float32":
+        return
+
+    mod = tvm.IRModule.from_expr(func)
+    typ = relay.transform.InferType()(mod)["main"].body.checked_type
+    np_data = get_random_ndarray((M, K), dtype)
+    np_weight0 = get_random_ndarray((N, K), dtype)
+    np_weight1 = get_random_ndarray((K, M), dtype)
+
+    params = {"weight0": np_weight0, "weight1": np_weight1}
+
+    rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
+    cutlass_rt_mod, dev, num_partition = profile_and_build(
+        mod,
+        params,
+        sm,
+        use_3xtf32=use_3xtf32,
+        use_ansor=False,
+    )
+    cutlass_ansor_rt_mod, dev, num_partition = profile_and_build(
+        mod,
+        params,
+        sm,
+        use_3xtf32=use_3xtf32,
+        use_ansor=True,
+    )
+    x = tvm.nd.array(np_data, device=dev)
+    cutlass_out = get_output(cutlass_rt_mod, ["input"], [x])
+    cutlass_ansor_out = get_output(cutlass_ansor_rt_mod, ["input"], [x])
+    ref_out = get_output(rt_mod_ref, ["input"], [x])
+
+    assert num_partition > 0
+    np.testing.assert_allclose(cutlass_out, ref_out, atol=atol, rtol=rtol)
+    np.testing.assert_allclose(cutlass_ansor_out, ref_out, atol=atol, rtol=rtol)
+
+    if run_benchmark:
+        print("CUTLASS:", cutlass_rt_mod.benchmark(dev, number=1, repeat=600))
+        print("CUTLASS with Ansor:", cutlass_ansor_rt_mod.benchmark(dev, number=1, repeat=600))
+        print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600))
+
+
+@tvm.testing.requires_cutlass
+def test_dense_transpose_dense():
+    verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)
+
+
 if __name__ == "__main__":
     tvm.testing.main()