diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py index aaf3f8c7f8d5..dcc7d232d8c4 100644 --- a/python/tvm/contrib/hexagon/meta_schedule.py +++ b/python/tvm/contrib/hexagon/meta_schedule.py @@ -17,7 +17,14 @@ """Meta schedule tuning utilities for Hexagon.""" import os import tempfile -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional +import tvm + +from tvm.ir.module import IRModule +from tvm.runtime import Module, NDArray +from tvm.target import Target +from tvm.driver import build as tvm_build +from tvm.tir.transform import RemoveWeightLayoutRewriteBlock from tvm.contrib.popen_pool import PopenPoolExecutor from tvm.meta_schedule.utils import cpu_count, derived_object from tvm.meta_schedule.builder import LocalBuilder @@ -121,14 +128,24 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path return costs -def get_hexagon_local_builder(): +def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None): """Return Hexagon-compatible Builder for meta schedule.""" def export_func(mod): binary_path = export_module(mod, tempfile.mkdtemp()) return str(binary_path) - return LocalBuilder(f_export=export_func) + def default_build_with_context( + mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]] + ) -> Module: + with pass_context: + mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod) + return tvm_build(mod, target=target) + + if pass_context is not None: + return LocalBuilder(f_build=default_build_with_context, f_export=export_func) + else: + return LocalBuilder(f_export=export_func) def get_hexagon_rpc_runner( diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 306c8cd2e14e..49c12c3e9dce 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -20,98 +20,100 @@ from .. 
import TensorIntrin -@T.prim_func -def dot_product_32x4_u8u8i32_desc( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "uint8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - for i in T.serial(0, 32): - for k in T.serial(0, 4): - with T.block("update"): - vi, vk = T.axis.remap("SR", [i, k]) - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") - - -@T.prim_func -def dot_product_32x4_u8u8i32_vrmpy( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "uint8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - - A_u8x4 = A.vload([0], "uint8x4") - A_i32 = T.reinterpret(A_u8x4, dtype="int32") - - B_i8x128 = B.vload([0, 0], dtype="uint8x128") - B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") - - C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"), - T.uint32(3), - C[T.ramp(T.int32(0), 1, 32)], - B_i32x32, - A_i32, - dtype="int32x32", - ) - - -@T.prim_func -def dot_product_32x4_u8i8i32_desc( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "int8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - for i in T.serial(0, 32): - for k in T.serial(0, 4): - with T.block("update"): - vi, vk = T.axis.remap("SR", [i, k]) - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") - - -@T.prim_func -def dot_product_32x4_u8i8i32_vrmpy( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "int8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - - A_u8x4 = A.vload([0], "uint8x4") - A_i32 = T.reinterpret(A_u8x4, dtype="int32") - - B_i8x128 = B.vload([0, 0], dtype="int8x128") - B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") - - C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"), - T.uint32(3), - C[T.ramp(T.int32(0), 1, 32)], - T.broadcast(A_i32, 32), - B_i32x32, - dtype="int32x32", - ) +def generate_dot_product_32x4_u8u8i32(mem_scope="global"): + @T.prim_func + def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) + B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + for i in T.serial(0, 32): + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + @T.prim_func + def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) + B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, 
dtype="int32") + + B_i8x128 = B.vload([0, 0], dtype="uint8x128") + B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") + + C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"), + T.uint32(3), + C[T.ramp(T.int32(0), 1, 32)], + B_i32x32, + A_i32, + dtype="int32x32", + ) + + return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy + + +def generate_dot_product_32x4_u8i8i32(mem_scope="global"): + @T.prim_func + def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) + B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + for i in T.serial(0, 32): + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + @T.prim_func + def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) + B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + + B_i8x128 = B.vload([0, 0], dtype="int8x128") + B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") + + C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"), + T.uint32(3), + C[T.ramp(T.int32(0), 1, 32)], + T.broadcast(A_i32, 32), + B_i32x32, + dtype="int32x32", + ) + + return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy" -TensorIntrin.register( - VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy -) +TensorIntrin.register(VRMPY_u8u8i32_INTRIN, *generate_dot_product_32x4_u8u8i32()) VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy" -TensorIntrin.register( - VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy -) +TensorIntrin.register(VRMPY_u8i8i32_INTRIN, *generate_dot_product_32x4_u8i8i32()) + +VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy" +TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm")) + +VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy" +TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm")) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 91eb67bbf457..e15b0a4e7ddb 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -18,7 +18,8 @@ import os import tempfile -from typing import Optional +from types import MappingProxyType +from typing import Any, Mapping, Optional import numpy as np import pytest @@ -34,7 +35,11 @@ from tvm.meta_schedule import postproc, schedule_rule from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.analysis import has_block -from tvm.tir.tensor_intrin.hexagon import 
VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
+from tvm.tir.tensor_intrin.hexagon import (
+    VRMPY_u8i8i32_INTRIN,
+    VRMPY_u8u8i32_INTRIN,
+    VRMPY_u8i8i32_VTCM_INTRIN,
+)
 
 from ..infrastructure import get_hexagon_target
 
@@ -133,7 +138,6 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
         # from 36 to 23, with negligible performance difference.
         module_equality="anchor-block",
     )
-
     return ms.relay_integration.compile_relay(
         database=database,
         mod=mod,
@@ -142,10 +146,13 @@
     )
 
 
-@pytest.mark.skip("End-to-end tuning is skipped on CI.")
 @tvm.testing.requires_hexagon
 def test_resnet50(hexagon_launcher):
     """Test Resnet50."""
+
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
     if not os.path.exists(MODEL_JSON):
         pytest.skip(msg="Run python export_models.py first.")
 
@@ -200,6 +207,44 @@ def test_resnet50(hexagon_launcher):
         print(debug_ex.profile(input_name=inp.copy()))
 
 
+def evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp, benchmark=False):
+    """Run the lowered module on Hexagon and check it against the LLVM reference output."""
+    with hexagon_launcher.create_session() as session:
+        graph_mod = session.get_executor_from_factory(hexagon_lowered)
+        graph_mod.set_input(input_name, inp.copy())
+        graph_mod.run()
+        output = graph_mod.get_output(0).numpy()
+
+        llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+        llvm_graph_mod.set_input(input_name, inp.copy())
+        llvm_graph_mod.run()
+        ref_result = llvm_graph_mod.get_output(0).numpy()
+
+        if benchmark:
+            time_ms = graph_mod.benchmark(session.device, number=1, repeat=1).mean * 1e3
+            print("hexagon time elapsed: ", time_ms)
+            debug_ex = session.get_graph_debug_executor(
+                hexagon_lowered.get_graph_json(), hexagon_lowered.lib
+            )
+            print(debug_ex.profile(input_name=inp.copy()))
+
+        np.testing.assert_allclose(ref_result, output, atol=1e-4, rtol=1e-5)
+
+
+def load_model():
+    """Load the Resnet50 model."""
+    if not os.path.exists(MODEL_JSON):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(MODEL_JSON, "r") as file:
+        mod = tvm.ir.load_json(file.read())
+
+    with open(MODEL_PARAMS, "rb") as file:
+        params = relay.load_param_dict(file.read())
+
+    return mod, params
+
+
 def _schedule_packed_8x8x32_conv2d():
     """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
     using 8x8x32 packed layout.
@@ -268,22 +313,39 @@ def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
     return schedule_fn
 
 
-def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
+def tune_conv2d_template(
+    mod,
+    scheduler,
+    schedule_tag,
+    params,
+    hexagon_launcher,
+    pass_config: Mapping[str, Any] = MappingProxyType({}),
+):
-    """Generate packed 8*8*32 template."""
+    """Tune a conv2d template with the given block scheduler and schedule tag."""
 
-    def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV):
-        _schedule_packed_8x8x32_conv2d()(sch, conv2d_block)
+    def schedule_rule_conv2d(sch: Schedule, conv2d_block: BlockRV):
+        scheduler()(sch, conv2d_block)
         return [sch]
 
-    register_func("meta_schedule.conv2d_NCHWc_int8.hexagon", schedule_rule_conv2d_packed_8x8x32)
+    register_func(
+        "meta_schedule.conv2d_NCHWc_int8.{}.hexagon".format(schedule_tag), schedule_rule_conv2d
+    )
 
     def schedule_conv2d_for_tune(sch: Schedule):
-        _schedule_packed_8x8x32_conv2d()(sch)
+        scheduler()(sch)
 
     # This line is necessary for link-params to take effect during
     # task extraction and relay.build(...).
mod = mod.with_attr("executor", EXECUTOR) + pass_context = None + if len(pass_config.items()) > 0: + pass_context = ( + tvm.transform.PassContext(opt_level=3, config=pass_config) + if pass_config is not None + else None + ) + with tempfile.TemporaryDirectory() as work_dir: database = ms.relay_integration.tune_relay( mod=mod, @@ -294,8 +356,8 @@ def schedule_conv2d_for_tune(sch: Schedule): max_trials_per_task=1, num_trials_per_iter=1, strategy="replay-trace", - builder=get_hexagon_local_builder(), - runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), + builder=get_hexagon_local_builder(pass_context), + runner=get_hexagon_rpc_runner(hexagon_launcher, number=1), # Apply MS auto scheduling rules for all blocks, but utilize # the custom block scheduling strategy registered above for # blocks annotated as `schedule_rule:meta_schedule.conv2d_NCHWc_int8` @@ -318,33 +380,37 @@ def schedule_conv2d_for_tune(sch: Schedule): # are treated as distinct tuning tasks. module_equality="ignore-ndarray", ) + + # Add default options so that it still uses the base config. + pass_config["relay.backend.use_meta_schedule"] = True + pass_config["relay.backend.tir_converter"] = "default" return ms.relay_integration.compile_relay( database=database, mod=mod, target=TARGET_HEXAGON, params=params, + pass_config=pass_config, ) -@pytest.mark.skip("End-to-end tuning is skipped on CI.") @tvm.testing.requires_hexagon def test_packed_8x8x32_resnet50(hexagon_launcher): """Test packed 8*8*32 Resnet50""" - if not os.path.exists(MODEL_JSON): - pytest.skip(msg="Run python export_models.py first.") - with open(MODEL_JSON, "r") as file: - mod = tvm.ir.load_json(file.read()) + if tvm.testing.utils.IS_IN_CI: + pytest.skip("Skipping test since it takes too long in CI.") + + mod, params = load_model() - with open(MODEL_PARAMS, "rb") as file: - params = relay.load_param_dict(file.read()) inp = np.random.randn(1, 3, 224, 224).astype("float32") input_name = "image" do_tune = True if do_tune: - hexagon_lowered = tune_packed_8x8x32_template(mod, params, hexagon_launcher) + hexagon_lowered = tune_conv2d_template( + mod, _schedule_packed_8x8x32_conv2d, "packed_8x8x32", params, hexagon_launcher + ) else: with tvm.transform.PassContext(opt_level=3): hexagon_lowered = relay.build( @@ -361,18 +427,112 @@ def test_packed_8x8x32_resnet50(hexagon_launcher): params=params, ) - with hexagon_launcher.start_session() as session: - graph_mod = session.get_executor_from_factory(hexagon_lowered) - graph_mod.set_input(input_name, inp.copy()) - graph_mod.run() - hexagon_output = graph_mod.get_output(0).numpy() + evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp) - llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) - llvm_graph_mod.set_input(input_name, inp.copy()) - llvm_graph_mod.run() - ref_result = llvm_graph_mod.get_output(0).numpy() - np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5) +def _schedule_async_dma_conv2d(): + """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc, + using 8x8x32 packed layout. 
+ """ + + def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: + if conv2d_block is None: + if has_block(sch, "conv2d_NCHWc_int8"): + conv2d_block = sch.get_block("conv2d_NCHWc_int8") + else: + return False + + assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] + + # Apply scheduling + + post_blocks = sch.get_consumers(conv2d_block) + if len(post_blocks) > 0: + # Fuse all intermediate post ops into the last op. + # This is equivalent to the traverse_inline function used in TE schedules. + while True: + next_post_blocks = [] + for post_block in post_blocks: + next_consumers = sch.get_consumers(post_block) + if len(next_consumers) > 0: + sch.compute_inline(post_block) + next_post_blocks += next_consumers + if len(next_post_blocks) == 0: + assert len(post_blocks) == 1 + outer_block = post_blocks[0] + break + post_blocks = next_post_blocks + else: + outer_block = conv2d_block + + # Move the conv2d mma into the injective post mma compute block + if outer_block != conv2d_block: + loops = sch.get_loops(outer_block) + # Compute at the second loop for pipelining. + sch.compute_at(conv2d_block, loops[1], preserve_unit_loops=True) + + # Add cache for input and output for copying data to vtcm. + input_a_cache = sch.cache_read(conv2d_block, 0, "global.vtcm") + sch.compute_at(input_a_cache, sch.get_loops(conv2d_block)[1]) + sch.fuse(*sch.get_loops(input_a_cache)[2:]) + + input_b_cache = sch.cache_read(conv2d_block, 1, "global.vtcm") + sch.compute_at(input_b_cache, sch.get_loops(conv2d_block)[1]) + sch.fuse(*sch.get_loops(input_b_cache)[2:]) + + output_cache_write = sch.cache_write(conv2d_block, 0, "global.vtcm") + sch.fuse(*sch.get_loops(output_cache_write)[2:]) + + conv2d_loops = sch.get_loops(block=conv2d_block) + o_c, k_h, k_w, x_0, x_1, i_c = conv2d_loops[-6:] + ic_o, ic_i = sch.split(loop=i_c, factors=[None, 4], preserve_unit_iters=True) + oc_o, oc_i = sch.split(loop=o_c, factors=[None, 32], preserve_unit_iters=True) + sch.reorder(oc_o, k_h, k_w, x_0, x_1, ic_o, oc_i, ic_i) + new_loops = sch.get_loops(block=conv2d_block) + sch.parallel(new_loops[4]) + sch.unroll(new_loops[5]) + # TODO(nverke): Add compute optimizations here. + sch.blockize(loop=oc_i) + + sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN) + + pipeline_loop = conv2d_loops[1] + sch.annotate(pipeline_loop, "software_pipeline_stage", [0, 0, 1, 2, 3]) + sch.annotate(pipeline_loop, "software_pipeline_order", [0, 1, 2, 3, 4]) + sch.annotate(pipeline_loop, "software_pipeline_async_stages", [0, 2]) + + return True + + return schedule_fn + + +@tvm.testing.requires_hexagon +def test_async_dma_resnet50(hexagon_launcher): + """Test async dma Resnet50""" + + if tvm.testing.utils.IS_IN_CI: + pytest.skip("Skipping test since it takes too long in CI.") + + mod, params = load_model() + + inp = np.random.randn(1, 3, 224, 224).astype("float32") + input_name = "image" + + pass_config = { + "tir.use_async_copy": 1, + "tir.merge_async_commit_queue_scope": False, + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default", + } + + hexagon_lowered = tune_conv2d_template( + mod, _schedule_async_dma_conv2d, "async_dma", params, hexagon_launcher, pass_config + ) + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + mod, tvm.target.Target(TARGET_LLVM, host=TARGET_LLVM), params=params + ) + evaluate_mod(hexagon_launcher, hexagon_lowered, llvm_lowered, input_name, inp, True) if __name__ == "__main__":
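As a rough usage sketch of the pieces this patch adds (illustrative, not taken from the patch itself): a PassContext carrying the async-DMA options is handed to get_hexagon_local_builder so tuning candidates are built under those options, and the matching RPC runner is created from a Hexagon launcher. Here `launcher` is a hypothetical stand-in for the `hexagon_launcher` fixture used in the tests above.

    import tvm
    from tvm.contrib.hexagon.meta_schedule import (
        get_hexagon_local_builder,
        get_hexagon_rpc_runner,
    )

    # Options that should be active both while building tuning candidates and
    # when compiling the final module, mirroring the async-DMA test above.
    pass_config = {
        "tir.use_async_copy": 1,
        "tir.merge_async_commit_queue_scope": False,
        "relay.backend.use_meta_schedule": True,
        "relay.backend.tir_converter": "default",
    }
    pass_context = tvm.transform.PassContext(opt_level=3, config=pass_config)

    # The builder compiles candidates locally inside `pass_context`; the runner
    # executes them on the device behind `launcher` (a HexagonLauncher).
    builder = get_hexagon_local_builder(pass_context)
    runner = get_hexagon_rpc_runner(launcher, number=1)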