From 985895034d39b4d3819033f57a3c81d78a17ba0f Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Sat, 22 Oct 2022 23:13:50 -0700 Subject: [PATCH 1/8] Add test that demonstrates applying a custom TIR schedule to E2E model. --- .../metaschedule_e2e/test_resnet50_int8.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 4c8d91dd27ef..e57240b0420f 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -18,14 +18,17 @@ import numpy as np import pytest import tempfile +from typing import Optional import tvm import tvm.testing from tvm import relay +from tvm._ffi import register_func from tvm.meta_schedule import postproc, schedule_rule from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner from tvm import meta_schedule as ms +from tvm.tir.schedule import BlockRV, Schedule from ..infrastructure import get_hexagon_target @@ -184,3 +187,162 @@ def test_resnet50(hexagon_launcher): hexagon_lowered.get_graph_json(), hexagon_lowered.lib ) print(debug_ex.profile(input_name=inp.copy())) + + +def _schedule_packed_8x8x32_conv2d(do_tune: bool): + """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc, + using 8x8x32 packed layout. + """ + + def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: + if sch.mod.attrs is not None and "conv2d" not in sch.mod.attrs["task_name"]: + return False + if conv2d_block == None: + conv2d_block = sch.get_block("compute") + assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] + + # Apply scheduling + + post_blocks = sch.get_consumers(conv2d_block) + if len(post_blocks) > 0: + # Fuse all intermediate post ops into the last op. + # This is equivalent to the traverse_inline function used in TE schedules. + while True: + next_post_blocks = [] + for post_block in post_blocks: + next_consumers = sch.get_consumers(post_block) + if len(next_consumers) > 0: + sch.compute_inline(post_block) + next_post_blocks += next_consumers + if len(next_post_blocks) == 0: + assert len(post_blocks) == 1 + outer_block = post_blocks[0] + break + post_blocks = next_post_blocks + else: + outer_block = conv2d_block + + # Move the conv2d mma into the injective post mma compute block + if outer_block != conv2d_block: + loops = sch.get_loops(outer_block) + # TODO(csullivan): May want to move this to an interior loop + # of the outer block doing injective/ewise ops + sch.compute_at(conv2d_block, loops[0]) + + def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32): + return [n, c, h // 8, w // 8, h % 8, w % 8, c32] + + # sch.cache_read() + # sch.transform_layout + if do_tune: + pass + else: + pass + + return True + + return schedule_fn + + +def tune_packed_8x8x32_template(mod, params, hexagon_launcher): + def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV): + _schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block) + return [sch] + + register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) + + # This line is necessary for link-params to take effect during + # task extraction and relay.build(...). + mod = mod.with_attr("executor", executor) + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + # for faster tuning + max_trials_global=20000, + max_trials_per_task=8, + num_trials_per_iter=8, + strategy="replay-trace", + # max_trials_global=20000, + # num_trials_per_iter=32, + # max_trials_per_task=128, + # strategy="evolutionary", + builder=get_hexagon_local_builder(), + runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), + # TODO(csullivan): How can I pass in the template here instead + space=ms.space_generator.PostOrderApply( + f_block_filter=None, + sch_rules="from-target", + postprocs=[], + mutator_probs="from-target", + ), + # Without this, the same workloads with different constant weights + # are treated as distinct tuning tasks. + module_equality="ignore-ndarray", + ) + + return ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + ) + + +@tvm.testing.requires_hexagon +def test_packed_8x8x32_resnet50(hexagon_launcher): + if not os.path.exists(model_json): + pytest.skip(msg="Run python export_models.py first.") + + with open(model_json, "r") as fi: + mod = tvm.ir.load_json(fi.read()) + + with open(model_params, "rb") as fi: + params = relay.load_param_dict(fi.read()) + inp = np.random.randn(1, 3, 224, 224).astype("float32") + input_name = "image" + + do_tune = True + + if do_tune: + hexagon_lowered = tune_packed_8x8x32_template(mod, params, hexagon_launcher) + else: + with tvm.transform.PassContext(opt_level=3): + hexagon_lowered = relay.build( + mod, + tvm.target.Target(target, host=target), + params=params, + executor=executor, + ) + + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + mod, + tvm.target.Target(target_llvm, host=target_llvm), + params=params, + ) + + with hexagon_launcher.start_session() as session: + graph_mod = session.get_executor_from_factory(hexagon_lowered) + graph_mod.set_input(input_name, inp.copy()) + graph_mod.run() + hexagon_output = graph_mod.get_output(0).numpy() + + llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_graph_mod.set_input(input_name, inp.copy()) + llvm_graph_mod.run() + ref_result = llvm_graph_mod.get_output(0).numpy() + + np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5) + + time_ms = graph_mod.benchmark(session.device, number=1, repeat=20).mean * 1e3 + + print("time elapsed: ", time_ms) + + debug_ex = session.get_graph_debug_executor( + hexagon_lowered.get_graph_json(), hexagon_lowered.lib + ) + print(debug_ex.profile(input_name=inp.copy())) From 0dc305b64682cfe87361ee8ed63fc20fd6e551d9 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Sun, 23 Oct 2022 21:35:23 -0700 Subject: [PATCH 2/8] Add example scheduling that demonstrates converting input and output activation to Hexagon's blocked layout. --- .../metaschedule_e2e/test_resnet50_int8.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index e57240b0420f..10c3d11dc6a0 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -225,21 +225,24 @@ def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: # Move the conv2d mma into the injective post mma compute block if outer_block != conv2d_block: loops = sch.get_loops(outer_block) - # TODO(csullivan): May want to move this to an interior loop - # of the outer block doing injective/ewise ops - sch.compute_at(conv2d_block, loops[0]) + # TODO(csullivan): Currently does all post conv2d mma steps + # directly after accumulation for one spatial pixel. May + # be desirable to do this with coarser spatial granularity + sch.compute_at(conv2d_block, loops[4]) def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32): return [n, c, h // 8, w // 8, h % 8, w % 8, c32] - # sch.cache_read() - # sch.transform_layout - if do_tune: - pass - else: - pass - - return True + # Add cache for input and output activation layout transform, + # note that weight is already in correct layout + input_cache = sch.cache_read(conv2d_block, 0, "global") + output_cache = sch.cache_write(outer_block, 0, "global") + sch.transform_layout( + conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 + ) + sch.transform_layout( + outer_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 + ) return schedule_fn From ef56d187ecb7b0871720c7686304b13620d7ccb5 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 24 Oct 2022 08:06:00 -0700 Subject: [PATCH 3/8] Manually generate search space for test clarity. See TODO, this may also disable auto scheduling for non-convolution ops. --- .../metaschedule_e2e/test_resnet50_int8.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 10c3d11dc6a0..45993e895e45 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -195,11 +195,13 @@ def _schedule_packed_8x8x32_conv2d(do_tune: bool): """ def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: - if sch.mod.attrs is not None and "conv2d" not in sch.mod.attrs["task_name"]: - return False if conv2d_block == None: - conv2d_block = sch.get_block("compute") - assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] + try: + conv2d_block = sch.get_block("conv2d_NCHWc_int8") + except: + return False + + assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"] # Apply scheduling @@ -237,12 +239,19 @@ def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32): # note that weight is already in correct layout input_cache = sch.cache_read(conv2d_block, 0, "global") output_cache = sch.cache_write(outer_block, 0, "global") + # Transform the layout of the input sch.transform_layout( conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 ) + # Transform the layout of the int32 accumulator + sch.transform_layout( + conv2d_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 + ) + # Transform the layout of the output sch.transform_layout( outer_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 ) + return True return schedule_fn @@ -252,7 +261,10 @@ def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV): _schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block) return [sch] - register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) + # register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) + + def schedule_conv2d_for_tune(sch: Schedule): + _schedule_packed_8x8x32_conv2d(do_tune=True)(sch) # This line is necessary for link-params to take effect during # task extraction and relay.build(...). @@ -275,12 +287,25 @@ def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV): # strategy="evolutionary", builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), - # TODO(csullivan): How can I pass in the template here instead - space=ms.space_generator.PostOrderApply( - f_block_filter=None, - sch_rules="from-target", + # TODO(csullivan): Configrm the below is accurate + # Enable MS auto scheduling for all ops, but utilize + # the custom scheduling strategy registered above for + # convolution + # space=ms.space_generator.PostOrderApply( + # f_block_filter=None, + # sch_rules="from-target", + # postprocs=[], + # mutator_probs="from-target", + # ), + # + # Constrain search space to only be the single + # schedule provided for convolution. No auto + # scheduling will be possible. + space=ms.space_generator.ScheduleFn( + schedule_conv2d_for_tune, + sch_rules=[], postprocs=[], - mutator_probs="from-target", + mutator_probs={}, ), # Without this, the same workloads with different constant weights # are treated as distinct tuning tasks. From e74a72e5e70d04aa602e212b380706e6d1ab2f34 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 24 Oct 2022 08:15:00 -0700 Subject: [PATCH 4/8] Skip E2E test in CI. --- .../contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 45993e895e45..fd711bd94e0d 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -320,6 +320,7 @@ def schedule_conv2d_for_tune(sch: Schedule): ) +@pytest.mark.skip("End-to-end tuning is skipped on CI.") @tvm.testing.requires_hexagon def test_packed_8x8x32_resnet50(hexagon_launcher): if not os.path.exists(model_json): From fc1dfea29bd756cad9c23a932383da81d5797b51 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 24 Oct 2022 08:23:07 -0700 Subject: [PATCH 5/8] Limit trials as the search space is currently empty. --- .../test_hexagon/metaschedule_e2e/test_resnet50_int8.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index fd711bd94e0d..cf1fb2a6e300 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -276,15 +276,10 @@ def schedule_conv2d_for_tune(sch: Schedule): target=target, params=params, work_dir=work_dir, - # for faster tuning max_trials_global=20000, - max_trials_per_task=8, - num_trials_per_iter=8, + max_trials_per_task=1, + num_trials_per_iter=1, strategy="replay-trace", - # max_trials_global=20000, - # num_trials_per_iter=32, - # max_trials_per_task=128, - # strategy="evolutionary", builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), # TODO(csullivan): Configrm the below is accurate From f881e537492c7b0ed6c58cc14241982a764271f1 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 24 Oct 2022 15:27:20 -0700 Subject: [PATCH 6/8] PR feedback: Remove latency measurement and profiling report. --- .../test_hexagon/metaschedule_e2e/test_resnet50_int8.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index cf1fb2a6e300..729a65000f01 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -361,12 +361,3 @@ def test_packed_8x8x32_resnet50(hexagon_launcher): ref_result = llvm_graph_mod.get_output(0).numpy() np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5) - - time_ms = graph_mod.benchmark(session.device, number=1, repeat=20).mean * 1e3 - - print("time elapsed: ", time_ms) - - debug_ex = session.get_graph_debug_executor( - hexagon_lowered.get_graph_json(), hexagon_lowered.lib - ) - print(debug_ex.profile(input_name=inp.copy())) From 29219b9dbe1bedcdcc11f72c020668889ac5de79 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 24 Oct 2022 21:22:38 -0700 Subject: [PATCH 7/8] Update comment to reflect PR discussion summary, remove TODO. --- .../metaschedule_e2e/test_resnet50_int8.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 729a65000f01..6519a8b56a38 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -261,7 +261,7 @@ def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV): _schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block) return [sch] - # register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) + register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32) def schedule_conv2d_for_tune(sch: Schedule): _schedule_packed_8x8x32_conv2d(do_tune=True)(sch) @@ -282,26 +282,24 @@ def schedule_conv2d_for_tune(sch: Schedule): strategy="replay-trace", builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), - # TODO(csullivan): Configrm the below is accurate - # Enable MS auto scheduling for all ops, but utilize - # the custom scheduling strategy registered above for - # convolution - # space=ms.space_generator.PostOrderApply( - # f_block_filter=None, - # sch_rules="from-target", - # postprocs=[], - # mutator_probs="from-target", - # ), - # - # Constrain search space to only be the single - # schedule provided for convolution. No auto - # scheduling will be possible. - space=ms.space_generator.ScheduleFn( - schedule_conv2d_for_tune, - sch_rules=[], + # Apply MS auto scheduling rules for all blocks, but utilize + # the custom block scheduling strategy registered above for + # blocks annotated as `schedule_rule:meta_schedule.conv2d_NCHWc_int8` + space=ms.space_generator.PostOrderApply( + f_block_filter=None, + sch_rules="from-target", postprocs=[], - mutator_probs={}, + mutator_probs="from-target", ), + # Constrain search space to only be the single + # schedule provided for all blocks. No auto + # scheduling will be possible. + # space=ms.space_generator.ScheduleFn( + # schedule_conv2d_for_tune, + # sch_rules=[], + # postprocs=[], + # mutator_probs={}, + # ), # Without this, the same workloads with different constant weights # are treated as distinct tuning tasks. module_equality="ignore-ndarray", From 9c9b9e3fdf6c0b801e8e9b627fb193905961ab86 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Sat, 29 Oct 2022 10:51:28 -0700 Subject: [PATCH 8/8] Use ScheduleFn space generator to disable any autotuning. --- .../metaschedule_e2e/test_resnet50_int8.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 6519a8b56a38..0a2bcc229924 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -285,26 +285,25 @@ def schedule_conv2d_for_tune(sch: Schedule): # Apply MS auto scheduling rules for all blocks, but utilize # the custom block scheduling strategy registered above for # blocks annotated as `schedule_rule:meta_schedule.conv2d_NCHWc_int8` - space=ms.space_generator.PostOrderApply( - f_block_filter=None, - sch_rules="from-target", - postprocs=[], - mutator_probs="from-target", - ), + # space=ms.space_generator.PostOrderApply( + # f_block_filter=None, + # sch_rules="from-target", + # postprocs=[], + # mutator_probs="from-target", + # ), # Constrain search space to only be the single # schedule provided for all blocks. No auto # scheduling will be possible. - # space=ms.space_generator.ScheduleFn( - # schedule_conv2d_for_tune, - # sch_rules=[], - # postprocs=[], - # mutator_probs={}, - # ), + space=ms.space_generator.ScheduleFn( + schedule_conv2d_for_tune, + sch_rules=[], + postprocs=[], + mutator_probs={}, + ), # Without this, the same workloads with different constant weights # are treated as distinct tuning tasks. module_equality="ignore-ndarray", ) - return ms.relay_integration.compile_relay( database=database, mod=mod,