From c0406a5d9600e889f406073112fcc54524042972 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 12:37:38 -0700 Subject: [PATCH 1/8] Added optional target blocks. --- .../space_generator/post_order_apply.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 50b49943f5ff..3e7c7c3a4a77 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -48,7 +48,7 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + explicit BlockCollector(const tir::Schedule& sch, const Array target_blocks = {}) : sch_(sch), target_blocks_(target_blocks) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); @@ -56,11 +56,23 @@ class BlockCollector : public tir::StmtVisitor { << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; block_names_.insert(block->name_hint); - blocks_to_collect_.push_back(block->name_hint); + // If target blocks are specified, only collect them. Otherwise collect all blocks. + if (target_blocks_.empty()) { + blocks_to_collect_.push_back(block->name_hint); + } else { + // Iterate over specified blocks and check if this is one of them. + for (String name : target_blocks_) { + if (name.compare(block->name_hint) == 0) { + blocks_to_collect_.push_back(block->name_hint); + } + } + } } /*! \brief The schedule to be collected */ const tir::Schedule& sch_; + /*! \brief An optional list of block names that will be collected, if not provided all blocks are collected. */ + const Array target_blocks_; /*! 
\brief The set of func name and block name pair */ std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ From babddf74546c481522cd8bcced3fe935350c7598 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 14:58:53 -0700 Subject: [PATCH 2/8] Checkpoint for debugging. --- include/tvm/meta_schedule/space_generator.h | 4 +- .../space_generator/post_order_apply.py | 8 +- .../space_generator/post_order_apply.cc | 33 ++++---- .../test_meta_schedule_post_order_apply.py | 83 +++++++++++++------ 4 files changed, 85 insertions(+), 43 deletions(-) diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f7d6cac31cab..808a55a5ee77 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -132,6 +132,8 @@ class SpaceGenerator : public runtime::ObjectRef { SpaceGenerator() = default; public: + /* A callback function that can be used to filter which blocks have generated spaces. */ + using BlockFilterFunc = runtime::TypedPackedFunc(const tir::BlockNode&)>; /*! * \brief Create a design space generator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. @@ -153,7 +155,7 @@ class SpaceGenerator : public runtime::ObjectRef { * to blocks in post-DFS order. * \return The design space generator created. 
*/ - TVM_DLL static SpaceGenerator PostOrderApply(); + TVM_DLL static SpaceGenerator PostOrderApply(BlockFilterFunc f_block_filter = nullptr); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 80f372a448f5..6683059928b4 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -29,8 +29,12 @@ class PostOrderApply(SpaceGenerator): rules to blocks in post-DFS order. """ - def __init__(self): + def __init__(self, target_blocks=[]): """Constructor""" + if target_blocks is None: + target_blocks = [] + if not isinstance(target_blocks, (list, tuple)): + target_blocks = [target_blocks] self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, target_blocks # type: ignore # pylint: disable=no-member ) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 3e7c7c3a4a77..56333f02f0a2 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -24,8 +24,8 @@ namespace meta_schedule { /*! \brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch) { // - return BlockCollector(sch).Run(); + static Array Collect(const tir::Schedule& sch, const SpaceGenerator::BlockFilterFunc f_block_filter = nullptr) { // + return BlockCollector(sch, f_block_filter).Run(); } private: @@ -48,7 +48,7 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! 
\brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, const Array target_blocks = {}) : sch_(sch), target_blocks_(target_blocks) {} + explicit BlockCollector(const tir::Schedule& sch, const SpaceGenerator::BlockFilterFunc f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); @@ -56,23 +56,21 @@ class BlockCollector : public tir::StmtVisitor { << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; block_names_.insert(block->name_hint); - // If target blocks are specified, only collect them. Otherwise collect all blocks. - if (target_blocks_.empty()) { - blocks_to_collect_.push_back(block->name_hint); - } else { - // Iterate over specified blocks and check if this is one of them. - for (String name : target_blocks_) { - if (name.compare(block->name_hint) == 0) { - blocks_to_collect_.push_back(block->name_hint); - } + // If filter function is provided, use it to selectively collect blocks. + if (f_block_filter_ != nullptr) { + Optional collect_block = f_block_filter_(*block); + if (collect_block.defined() && collect_block) { + blocks_to_collect_.push_back(block->name_hint); } + } else { + blocks_to_collect_.push_back(block->name_hint); } } /*! \brief The schedule to be collected */ const tir::Schedule& sch_; - /*! \brief An optional list of block names that will be collected, if not provided all blocks are collected. */ - const Array target_blocks_; + /*! \brief An optional packed func that allows only certain blocks to be collected. */ + const SpaceGenerator::BlockFilterFunc f_block_filter_; /*! \brief The set of func name and block name pair */ std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ @@ -93,6 +91,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array sch_rules_{nullptr}; /*! 
\brief The logging function to use. */ PackedFunc logging_func; + /*! \brief Optional block names to target. If not specified all blocks will have spaces generated. */ + SpaceGenerator::BlockFilterFunc f_block_filter_ = nullptr; void VisitAttrs(tvm::AttrVisitor* v) { // `rand_state_` is not visited @@ -119,7 +119,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array result{sch}; // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one - Array all_blocks = BlockCollector::Collect(sch); + Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); Array> rules{NullOpt}; rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end()); for (Optional sch_rule : rules) { @@ -189,8 +189,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply() { +SpaceGenerator SpaceGenerator::PostOrderApply(SpaceGenerator::BlockFilterFunc f_block_filter) { ObjectPtr n = make_object(); + n->f_block_filter_ = f_block_filter; return SpaceGenerator(n); } diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 21d29ac74d82..fa4295b03631 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,6 +22,7 @@ import pytest import tvm +from tvm.meta_schedule.default_config import target import tvm.testing from tvm._ffi import register_func from tvm.error import TVMError @@ -195,6 +196,29 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: return result +@derived_object +class TrinityDoubleRule(PyScheduleRule): + def _initialize_with_tune_context(self, context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() 
+ i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + @derived_object class ReorderScheduleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -283,28 +307,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): def test_meta_schedule_post_order_apply_remove_block(): - @derived_object - class TrinityDouble(PyScheduleRule): - def _initialize_with_tune_context(self, context: "TuneContext") -> None: - pass - - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: - if _is_root(sch, block): - return [sch] - new_sch = sch.copy() - i, j = new_sch.get_loops(block=block) - i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) - j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) - new_sch.reorder(i_0, j_0, i_1, j_1) - result = [new_sch] - new_sch = sch.copy() - i, j = new_sch.get_loops(block=block) - i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) - j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) - new_sch.reorder(i_0, j_0, i_1, j_1) - result.append(new_sch) - return result - @derived_object class RemoveBlock(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -342,7 +344,7 @@ def correct_trace(a, b, c, d): target=Target("llvm"), task_name="Remove Block Task", space_generator=PostOrderApply(), - sch_rules=[RemoveBlock(), TrinityDouble()], + sch_rules=[RemoveBlock(), TrinityDoubleRule()], ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -385,5 +387,38 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> 
List[Schedule]: assert called +def test_target_blocks_search_space(): + # Test that specific blocks of trinity matmul can be targeted. + def _get_sch(target_blocks=[]): + mod = TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + space_generator=PostOrderApply(target_blocks=target_blocks), + sch_rules=[TrinityDoubleRule()], + ) + post_order_apply = context.space_generator + schs = post_order_apply.generate_design_space(mod) + return schs + + # Start by checking that by default each block has a space generated. + schs = _get_sch() + assert len(schs) == 8 + + # Next check that we can target a specific block and only get its' revelant schedules. + schs = _get_sch(["B"]) + assert len(schs) == 2 + + # Check that extracting two blocks works. + schs = _get_sch(["A", "C"]) + assert len(schs) == 4 + + # Finally check that all blocks can be extracted by name. + schs = _get_sch(["A", "B", "C"]) + assert len(schs) == 8 + + if __name__ == "__main__": - tvm.testing.main() + #tvm.testing.main() + test_target_blocks_search_space() From 831477cfcecf75c5437ea28a85a9369dad6f9f84 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 16:27:08 -0700 Subject: [PATCH 3/8] Building with packedfunc filter. --- include/tvm/meta_schedule/space_generator.h | 4 +--- .../space_generator/post_order_apply.cc | 21 ++++++++++--------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 808a55a5ee77..2df040e5d941 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -132,8 +132,6 @@ class SpaceGenerator : public runtime::ObjectRef { SpaceGenerator() = default; public: - /* A callback function that can be used to filter which blocks have generated spaces. */ - using BlockFilterFunc = runtime::TypedPackedFunc(const tir::BlockNode&)>; /*! 
* \brief Create a design space generator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. @@ -155,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef { * to blocks in post-DFS order. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator PostOrderApply(BlockFilterFunc f_block_filter = nullptr); + TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 56333f02f0a2..43d584baa54f 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -24,7 +24,7 @@ namespace meta_schedule { /*! \brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch, const SpaceGenerator::BlockFilterFunc f_block_filter = nullptr) { // + static Array Collect(const tir::Schedule& sch, const runtime::PackedFunc f_block_filter = nullptr) { // return BlockCollector(sch, f_block_filter).Run(); } @@ -48,7 +48,7 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, const SpaceGenerator::BlockFilterFunc f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} + explicit BlockCollector(const tir::Schedule& sch, const runtime::PackedFunc f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! 
\brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); @@ -56,13 +56,14 @@ class BlockCollector : public tir::StmtVisitor { << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; block_names_.insert(block->name_hint); + // If filter function is provided, use it to selectively collect blocks. + // Otherwise collect all blocks. + Bool collect_block = Bool(true); if (f_block_filter_ != nullptr) { - Optional collect_block = f_block_filter_(*block); - if (collect_block.defined() && collect_block) { - blocks_to_collect_.push_back(block->name_hint); - } - } else { + collect_block = f_block_filter_(GetRef(block)); + } + if (collect_block) { blocks_to_collect_.push_back(block->name_hint); } } @@ -70,7 +71,7 @@ class BlockCollector : public tir::StmtVisitor { /*! \brief The schedule to be collected */ const tir::Schedule& sch_; /*! \brief An optional packed func that allows only certain blocks to be collected. */ - const SpaceGenerator::BlockFilterFunc f_block_filter_; + const runtime::PackedFunc f_block_filter_; /*! \brief The set of func name and block name pair */ std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ @@ -92,7 +93,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*! \brief The logging function to use. */ PackedFunc logging_func; /*! \brief Optional block names to target. If not specified all blocks will have spaces generated. 
*/ - SpaceGenerator::BlockFilterFunc f_block_filter_ = nullptr; + runtime::PackedFunc f_block_filter_ = nullptr; void VisitAttrs(tvm::AttrVisitor* v) { // `rand_state_` is not visited @@ -189,7 +190,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply(SpaceGenerator::BlockFilterFunc f_block_filter) { +SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) { ObjectPtr n = make_object(); n->f_block_filter_ = f_block_filter; return SpaceGenerator(n); From 439d6ed48bacc9058f31748a98154aececebdd5e Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 16:55:56 -0700 Subject: [PATCH 4/8] Extended tune_tir API to support named blocks. --- .../space_generator/post_order_apply.py | 8 ++---- python/tvm/meta_schedule/tune.py | 26 +++++++++++++++++++ .../space_generator/post_order_apply.cc | 10 ++++--- .../test_meta_schedule_post_order_apply.py | 26 ++++++++++--------- .../unittest/test_meta_schedule_tune_tir.py | 1 + 5 files changed, 50 insertions(+), 21 deletions(-) diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 6683059928b4..3fd207d206fd 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -29,12 +29,8 @@ class PostOrderApply(SpaceGenerator): rules to blocks in post-DFS order. 
""" - def __init__(self, target_blocks=[]): + def __init__(self, filter_fn=None): """Constructor""" - if target_blocks is None: - target_blocks = [] - if not isinstance(target_blocks, (list, tuple)): - target_blocks = [target_blocks] self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, target_blocks # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, filter_fn # type: ignore # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index fbbe24b32e4d..1b313dd8f761 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.runtime import Module, NDArray, vm from tvm.target import Target from tvm.te import Tensor, create_prim_func @@ -364,6 +365,7 @@ def tune_tir( cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, space: Optional[FnSpaceGenerator] = None, + blocks: Optional[List[str]] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, @@ -392,6 +394,21 @@ def tune_tir( The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. + space : Optional[FnSpaceGenerator] + The space generator to use. + blocks : Optional[List[str]] + A list of block names to tune. If provided, other blocks + will not be optimized. + sch_rules : Optional[FnScheduleRule] + The search rules to use. + postprocs : Optional[FnPostproc] + The postprocessors to use. + mutator_probs : Optional[FnMutatorProb] + The probability distribution to use different mutators. + task_name : str + The name of the function to extract schedules from. 
+ num_threads : Optional[int] + The number of threads to use Returns ------- @@ -407,6 +424,15 @@ def tune_tir( params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}], ) + if blocks is not None: + assert space is None, "Only one of blocks and space can be specified." + # Create a filter function to identify named blocks. + def _filter_fn(block, target_names) -> bool: + return block.name_hint in target_names + + # Create a space generator that targets specific blocks. + space = PostOrderApply(filter_fn=lambda block: _filter_fn(block, blocks)) + # pylint: disable=protected-access mod = default_config.mod(mod) target = default_config.target(target) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 43d584baa54f..51dea2c2fe90 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -24,7 +24,8 @@ namespace meta_schedule { /*! \brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch, const runtime::PackedFunc f_block_filter = nullptr) { // + static Array Collect(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) { // return BlockCollector(sch, f_block_filter).Run(); } @@ -48,7 +49,9 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, const runtime::PackedFunc f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} + explicit BlockCollector(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) + : sch_(sch), f_block_filter_(f_block_filter) {} /*! 
\brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); @@ -92,7 +95,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array sch_rules_{nullptr}; /*! \brief The logging function to use. */ PackedFunc logging_func; - /*! \brief Optional block names to target. If not specified all blocks will have spaces generated. */ + /*! \brief Optional filter function selecting blocks. If not specified all blocks will have spaces generated. + */ runtime::PackedFunc f_block_filter_ = nullptr; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index fa4295b03631..54b7bf0913ac 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -216,7 +216,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) new_sch.reorder(i_0, j_0, i_1, j_1) result.append(new_sch) - return result + return result @derived_object @@ -389,36 +389,38 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: def test_target_blocks_search_space(): # Test that specific blocks of trinity matmul can be targeted. 
- def _get_sch(target_blocks=[]): + def filter_fn(block, target_names) -> bool: + return block.name_hint in target_names + + def _get_sch(filter_fn): mod = TrinityMatmul context = TuneContext( mod=mod, target=Target("llvm"), task_name="Custom Search Space Task", - space_generator=PostOrderApply(target_blocks=target_blocks), + space_generator=PostOrderApply(filter_fn=filter_fn), sch_rules=[TrinityDoubleRule()], ) post_order_apply = context.space_generator - schs = post_order_apply.generate_design_space(mod) + schs = post_order_apply.generate_design_space(mod) return schs # Start by checking that by default each block has a space generated. - schs = _get_sch() + schs = _get_sch(None) assert len(schs) == 8 # Next check that we can target a specific block and only get its' revelant schedules. - schs = _get_sch(["B"]) + schs = _get_sch(lambda block: filter_fn(block, ["B"])) assert len(schs) == 2 - # Check that extracting two blocks works. - schs = _get_sch(["A", "C"]) + ## Check that extracting two blocks works. + schs = _get_sch(lambda block: filter_fn(block, ["A", "C"])) assert len(schs) == 4 - # Finally check that all blocks can be extracted by name. - schs = _get_sch(["A", "B", "C"]) + ## Finally check that all blocks can be extracted by name. 
+ schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"])) assert len(schs) == 8 if __name__ == "__main__": - #tvm.testing.main() - test_target_blocks_search_space() + tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 0e8c205230e6..19c5b892e691 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -66,6 +66,7 @@ def test_tune_matmul_cpu(): max_trials_global=32, ), work_dir=work_dir, + blocks=["update"], ) if sch is None: print("No valid schedule found!") From 20797ccb4f62a10fc65c128ada3fc52f57f25239 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 17:00:28 -0700 Subject: [PATCH 5/8] Remove accidental import. --- tests/python/unittest/test_meta_schedule_post_order_apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 54b7bf0913ac..52281c9f3487 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,7 +22,6 @@ import pytest import tvm -from tvm.meta_schedule.default_config import target import tvm.testing from tvm._ffi import register_func from tvm.error import TVMError From 0cdbce459f0b9c3c648904745ccae9a31a55531e Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 21:25:17 -0700 Subject: [PATCH 6/8] Improve integration test. 
--- python/tvm/meta_schedule/tune.py | 7 +-- .../unittest/test_meta_schedule_tune_tir.py | 47 ++++++++++++++++--- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 1b313dd8f761..c8101156e881 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -397,8 +397,9 @@ def tune_tir( space : Optional[FnSpaceGenerator] The space generator to use. blocks : Optional[List[str]] - A list of block names to tune. If provided, other blocks - will not be optimized. + A list of block names specifying blocks to be tuned. Note that if + the list is not None, blocks outside this list will not be tuned. + Only one of this argument and space may be provided. sch_rules : Optional[FnScheduleRule] The search rules to use. postprocs : Optional[FnPostproc] @@ -425,7 +426,7 @@ def tune_tir( ) if blocks is not None: - assert space is None, "Only one of blocks and space can be specified." + assert space is None, "Can not specify blocks to tune when a search space is given." # Create a filter function to identify named blocks. def _filter_fn(block, target_names) -> bool: return block.name_hint in target_names diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 19c5b892e691..c75d1d0ef3d5 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=missing-docstring +# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable import logging import tempfile import numpy as np @@ -34,9 +34,6 @@ logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -# pylint: disable=no-member,invalid-name,unused-variable - - @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) @@ -50,7 +47,19 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -# pylint: enable=no-member,invalid-name,unused-variable +@T.prim_func +def two_step(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 @pytest.mark.skip("Integration test") @@ -66,7 +75,6 @@ def test_tune_matmul_cpu(): max_trials_global=32, ), work_dir=work_dir, - blocks=["update"], ) if sch is None: print("No valid schedule found!") @@ -75,6 +83,32 @@ def test_tune_matmul_cpu(): print(sch.trace) +@pytest.mark.skip("Integration test") +def test_tune_block_cpu(): + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=two_step, + target=Target("llvm --num-cores=16"), + config=TuneConfig( + strategy="replay_trace", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=32, + ), + work_dir=work_dir, + blocks=["B"], + ) + if sch is None: + print("No valid schedule found!") + else: + # Since only block B was tuned, we should now be able + # to manually inline block A without an error. 
+ block_a = sch.get_block("A") + sch.compute_inline(block=block_a) + print(sch.mod.script()) + print(sch.trace) + + @pytest.mark.skip("Integration test") def test_tune_matmul_cuda(): with tempfile.TemporaryDirectory() as work_dir: @@ -142,3 +176,4 @@ def f_timer(rt_mod, dev, input_data): test_tune_matmul_cpu() test_tune_matmul_cuda() test_tune_run_module_via_rpc() + test_tune_block_cpu() From f288c36983863331b6b38ecfea1065637505b4bf Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Aug 2022 22:11:25 -0700 Subject: [PATCH 7/8] Change names for more consistency. --- .../space_generator/post_order_apply.py | 12 ++++++++++-- python/tvm/meta_schedule/tune.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 3fd207d206fd..6e2a2c52b1a1 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -27,10 +27,18 @@ class PostOrderApply(SpaceGenerator): """ PostOrderApply is the design space generator that generates design spaces by applying schedule rules to blocks in post-DFS order. + + Parameters + ---------- + f_block_filter : Optional[function] + An optional callback function that is used to filter which blocks have schedules generated + for them. The function should take in a block and return True if a schedule should + be generated or False if that block should be skipped. If no function is provided + all blocks will have schedules generated. 
""" - def __init__(self, filter_fn=None): + def __init__(self, f_block_filter=None): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, filter_fn # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, f_block_filter # type: ignore # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index c8101156e881..447fb56637ef 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -428,11 +428,11 @@ def tune_tir( if blocks is not None: assert space is None, "Can not specify blocks to tune when a search space is given." # Create a filter function to identify named blocks. - def _filter_fn(block, target_names) -> bool: + def _f_block_filter(block, target_names) -> bool: return block.name_hint in target_names # Create a space generator that targets specific blocks. - space = PostOrderApply(filter_fn=lambda block: _filter_fn(block, blocks)) + space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks)) # pylint: disable=protected-access mod = default_config.mod(mod) From 90f648032ede790133dc18c39176e2742e763cdc Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 9 Aug 2022 09:17:47 -0700 Subject: [PATCH 8/8] Update integration test. 
--- .../test_meta_schedule_post_order_apply.py | 2 +- .../unittest/test_meta_schedule_tune_tir.py | 31 ++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 52281c9f3487..97a49602fb26 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -397,7 +397,7 @@ def _get_sch(filter_fn): mod=mod, target=Target("llvm"), task_name="Custom Search Space Task", - space_generator=PostOrderApply(filter_fn=filter_fn), + space_generator=PostOrderApply(f_block_filter=filter_fn), sch_rules=[TrinityDoubleRule()], ) post_order_apply = context.space_generator diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index c75d1d0ef3d5..6ab5f9b8c5c4 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -23,12 +23,14 @@ import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule import TuneConfig, tune_tir +from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.local_rpc import LocalRPC +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target import Target -from tvm.tir import Schedule +from tvm.tir.schedule import BlockRV, Schedule logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -85,6 +87,18 @@ def test_tune_matmul_cpu(): @pytest.mark.skip("Integration test") def test_tune_block_cpu(): + @derived_object + class RemoveBlock(PyScheduleRule): + def _initialize_with_tune_context(self, context: TuneContext) -> None: + pass + + def apply(self, 
sch: Schedule, block: BlockRV): + if sch.get(block).name_hint == "root": + return [sch] + sch = sch.copy() + sch.compute_inline(block) + return [sch] + with tempfile.TemporaryDirectory() as work_dir: sch: Schedule = tune_tir( mod=two_step, @@ -96,17 +110,10 @@ def test_tune_block_cpu(): max_trials_global=32, ), work_dir=work_dir, - blocks=["B"], + blocks=["A"], + sch_rules=lambda *args: [RemoveBlock()], ) - if sch is None: - print("No valid schedule found!") - else: - # Since only block B was tuned, we should now be able - # to manually inline block A without an error. - block_a = sch.get_block("A") - sch.compute_inline(block=block_a) - print(sch.mod.script()) - print(sch.trace) + assert sch is not None @pytest.mark.skip("Integration test")