8 changes: 7 additions & 1 deletion python/tvm/auto_scheduler/relay_integration.py
@@ -84,6 +84,7 @@ def extract_tasks(
     include_simple_tasks=False,
     dump_workload_to_dag_log=None,
     opt_level=3,
+    other_targets=None,
 ):
     """Extract tuning tasks from a relay program.
 
@@ -105,6 +106,8 @@ def extract_tasks(
         A file to dump an association between the workload keys and the actual DAG
     opt_level : Optional[int]
         The optimization level of the task extractions.
+    other_targets : Optional[List[tvm.target.Target]]
+        Additional targets to run call_all_topi_funcs with, e.g., the cutlass target.
 
     Returns
     -------
@@ -125,12 +128,15 @@
     old_verbose = dispatch_ctx.verbose
     dispatch_ctx.verbose = 0
 
+    targets = [target]
+    if other_targets is not None:
+        targets += other_targets
     errors = []
     with env:
         # Wrap build call in a new thread to avoid the conflict
         # between python's multiprocessing and tvm's thread pool
         build_thread = threading.Thread(
-            target=call_all_topi_funcs, args=(mod, params, target, errors, opt_level)
+            target=call_all_topi_funcs, args=(mod, params, targets, errors, opt_level)
        )
         build_thread.start()
         build_thread.join()
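For context, a minimal sketch of how the new parameter is meant to be used from client code. The `mod`/`params` module and the exact cutlass target construction here are illustrative assumptions, not part of this diff:

    # Extract Ansor tasks for the CUDA target while also letting the CUTLASS
    # BYOC target participate in call_all_topi_funcs during extraction.
    # `mod` and `params` are assumed to hold an already-partitioned Relay module.
    import tvm
    from tvm import auto_scheduler

    cuda = tvm.target.Target("cuda")
    cutlass = tvm.target.Target("cutlass")  # assumes the cutlass target kind is registered
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod, params, cuda, include_simple_tasks=True, other_targets=[cutlass]
    )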
140 changes: 138 additions & 2 deletions tests/python/contrib/test_cutlass.py
@@ -15,13 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+import tempfile
 import math
 import tvm
 from tvm import relay
 from tvm.contrib.cudnn import conv_output_shape
 import numpy as np
+from tvm.relay import op as _op
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay.op.contrib.cutlass import partition_for_cutlass
+from tvm import auto_scheduler
 from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType
 from tvm.contrib.cutlass import (
     has_cutlass,
@@ -235,6 +238,32 @@ def get_conv2d_backward_weight(
     )
 
 
+def get_dense_transpose_dense(M, N, K, dtype="float16"):
+    """
+    output = nn.dense(_op.transpose(nn.dense(input, weight0), axes=(1, 0)), weight1)
+
+    dense0: [M, K] * [N, K] -> [M, N]
+    transpose: [M, N] -> [N, M]
+    dense1: [N, M] * [K, M] -> [N, K]
+
+    input: [M, K]
+    weight0: [N, K]
+    weight1: [K, M]
+    """
+    input_shape = (M, K)
+    weight0_shape = (N, K)
+    weight1_shape = (K, M)
+
+    input = relay.var("input", shape=input_shape, dtype=dtype)
+    weight0 = relay.var("weight0", shape=weight0_shape, dtype=dtype)
+    weight1 = relay.var("weight1", shape=weight1_shape, dtype=dtype)
+
+    output0 = relay.nn.dense(input, weight0, out_dtype=dtype)
+    input1 = _op.transpose(output0, axes=(1, 0))
+    output = relay.nn.dense(input1, weight1, out_dtype=dtype)
+    return output
+
+
 def convert_conv2d_layout(mod, desired_layouts):
     with tvm.transform.PassContext(opt_level=3):
         seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
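The shape bookkeeping in the docstring above can be sanity-checked with a few lines of NumPy (illustrative only; Relay's `nn.dense(x, w)` computes `x @ w.T`):

    # Mirror the dense0 -> transpose -> dense1 chain with NumPy arrays.
    import numpy as np

    M, N, K = 4, 3, 2  # arbitrary small sizes
    inp = np.zeros((M, K), dtype="float16")
    w0 = np.zeros((N, K), dtype="float16")
    w1 = np.zeros((K, M), dtype="float16")

    out0 = inp @ w0.T    # dense0: [M, K] x [N, K] -> [M, N]
    out = out0.T @ w1.T  # dense1: [N, M] x [K, M] -> [N, K]
    assert out.shape == (N, K)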
@@ -257,6 +286,8 @@ def profile_and_build(
     tmp_dir="./tmp",
     use_fast_math=False,
     use_3xtf32=True,
+    use_ansor=False,
+    ansor_tuning=False,
 ):
     logging.info("before partitioning:\n%s", mod)
     mod = partition_for_cutlass(mod)
@@ -279,8 +310,53 @@
             },
             host=host,
         )
-    with tvm.transform.PassContext(opt_level=3):
-        lib = relay.build(mod, target=[cuda, cutlass], params=params)
+
+    if use_ansor:
+        with tvm.transform.PassContext(
+            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
+        ):
+            tasks, task_weights = auto_scheduler.extract_tasks(
+                mod, params, cuda, include_simple_tasks=True, opt_level=3, other_targets=[cutlass]
+            )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            logging.info(
+                f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) ====="
+            )
+            logging.info(task.compute_dag)
+
+        with tempfile.NamedTemporaryFile() as fp:
+            log_file = fp.name
+
+            # auto-tuning is disabled by default
+            if ansor_tuning:
+                measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+                    repeat=3, min_repeat_ms=200, timeout=10
+                )
+                tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+                tuner.tune(
+                    auto_scheduler.TuningOptions(
+                        num_measure_trials=100,
+                        runner=measure_ctx.runner,
+                        measure_callbacks=[
+                            auto_scheduler.RecordToFile(log_file),
+                        ],
+                    )
+                )
+
+            with auto_scheduler.ApplyHistoryBest(log_file):
+                with tvm.transform.PassContext(
+                    opt_level=3,
+                    config={"relay.backend.use_auto_scheduler": True},
+                ):
+                    lib = relay.build(
+                        mod,
+                        target=cuda,
+                        target_host=host,
+                        params=params,
+                    )
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=[cuda, cutlass], params=params)
     lib = finalize_modules(lib, "compile.so", tmp_dir)
     dev = tvm.device("cuda", 0)
     rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
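A hedged sketch of driving the new code path directly; the sizes and the empty params dict are assumptions for illustration, and the test below does the same thing with real weights:

    # Build the dense -> transpose -> dense module through the Ansor code path.
    # Pass ansor_tuning=True to actually run the 100-trial search instead of
    # applying an empty tuning log.
    mod = tvm.IRModule.from_expr(get_dense_transpose_dense(96, 64, 64))
    rt_mod, dev, num_partition = profile_and_build(
        mod, {}, 80, use_ansor=True, ansor_tuning=False
    )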
@@ -959,5 +1035,65 @@ def test_conv2d_bwd():
     )
 
 
+def verify_dense_transpose_dense(
+    func,
+    M,
+    N,
+    K,
+    ref_target="cuda",
+    sm=80,
+    atol=1e-5,
+    rtol=1e-5,
+    run_benchmark=False,
+    dtype="float16",
+    use_3xtf32=True,
+):
+    assert has_cutlass()
+    if sm < 80 and dtype == "float32":
+        return
+
+    mod = tvm.IRModule.from_expr(func)
+    typ = relay.transform.InferType()(mod)["main"].body.checked_type
+    np_data = get_random_ndarray((M, K), dtype)
+    np_weight0 = get_random_ndarray((N, K), dtype)
+    np_weight1 = get_random_ndarray((K, M), dtype)
+
+    params = {"weight0": np_weight0, "weight1": np_weight1}
+
+    rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
+    cutlass_rt_mod, dev, num_partition = profile_and_build(
+        mod,
+        params,
+        sm,
+        use_3xtf32=use_3xtf32,
+        use_ansor=False,
+    )
+    cutlass_ansor_rt_mod, dev, num_partition = profile_and_build(
+        mod,
+        params,
+        sm,
+        use_3xtf32=use_3xtf32,
+        use_ansor=True,
+    )
+    x = tvm.nd.array(np_data, device=dev)
+    cutlass_out = get_output(cutlass_rt_mod, ["input"], [x])
+    cutlass_ansor_out = get_output(cutlass_ansor_rt_mod, ["input"], [x])
+    ref_out = get_output(rt_mod_ref, ["input"], [x])
+
+    assert num_partition > 0
+    np.testing.assert_allclose(cutlass_out, ref_out, atol=atol, rtol=rtol)
+    np.testing.assert_allclose(cutlass_ansor_out, ref_out, atol=atol, rtol=rtol)
+
+    if run_benchmark:
+        print("CUTLASS:", cutlass_rt_mod.benchmark(dev, number=1, repeat=600))
+        print("CUTLASS with Ansor:", cutlass_ansor_rt_mod.benchmark(dev, number=1, repeat=600))
+        print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600))
+
+
+@tvm.testing.requires_cutlass
+def test_dense_transpose_dense():
+    verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)
+
+
 if __name__ == "__main__":
     tvm.testing.main()