diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 2df040e5d941..2c1b2d4e4d7d 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -141,7 +141,14 @@ class SpaceGenerator : public runtime::ObjectRef { TVM_DLL static SpaceGenerator PySpaceGenerator( PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space); - + /*! + * \brief Create a design space generator with customized schedule function. + * \param schedule_fn The schedule function, which can have the following signatures: + * 1) void(Schedule) + * 2) Schedule(Schedule) + * 3) Array(Schedule) + */ + TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn); /*! * \brief Create a design space generator that is union of multiple design space generators. * \param space_generators An array of design space generators to be unioned. diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index 007fa6da4559..d2039c4511c9 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -20,6 +20,6 @@ space for generation of measure candidates. """ from .post_order_apply import PostOrderApply -from .schedule_fn import SCH_FN_TYPE, ScheduleFn -from .space_generator import PySpaceGenerator, SpaceGenerator +from .schedule_fn import ScheduleFn +from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator from .space_generator_union import SpaceGeneratorUnion diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index 97498bcbf59d..d6b063dcb263 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -14,78 +14,34 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -Meta schedule design space generators that generates design -space via a schedule function. -""" -from typing import TYPE_CHECKING, Callable, List, Union +"""Union of meta Schedule design space generators.""" +from tvm._ffi import register_object -from tvm.ir import IRModule -from tvm.ir.container import Array -from tvm.meta_schedule.utils import derived_object -from tvm.tir.schedule import Schedule +from .. import _ffi_api +from .space_generator import SpaceGenerator -from .space_generator import PySpaceGenerator -if TYPE_CHECKING: - from ..tune_context import TuneContext +@register_object("meta_schedule.ScheduleFn") +class ScheduleFn(SpaceGenerator): + """Create a design space generator with customized schedule function. + The schedule function can have the following signatures: + - 1) [Schedule] -> None + - 2) [Schedule] -> Schedule + - 3) [Schedule] -> List[Schedule] + """ -SCH_FN_TYPE = Union[ # pylint: disable=invalid-name - Callable[[Schedule], None], # No output - Callable[[Schedule], Schedule], # Single output - Callable[[Schedule], List[Schedule]], # Multiple outputs -] - - -@derived_object -class ScheduleFn(PySpaceGenerator): - """A design space generator with design spaces specified by a schedule function.""" - - def __init__(self, sch_fn: SCH_FN_TYPE): + def __init__(self, sch_fn: SpaceGenerator.ScheduleFnType): """Constructor. Parameters ---------- - sch_fn : SCH_FN_TYPE - The schedule function. - """ - super().__init__() - self.sch_fn = sch_fn - - def _initialize_with_tune_context(self, context: "TuneContext") -> None: - """Initialize the design space generator with tuning context. - - Parameters - ---------- - context : TuneContext - The tuning context for initializing the design space generator. - """ - - def generate_design_space(self, mod: IRModule) -> List[Schedule]: - """Generate design spaces given a module. - - Parameters - ---------- - mod : IRModule - The module used for design space generation. - - Returns - ------- - design_spaces : List[Schedule] - The generated design spaces, i.e., schedules. + sch_fn : SpaceGenerator.ScheduleFnType + The schedule function, which can have the following signatures: + - 1) [Schedule] -> None + - 2) [Schedule] -> Schedule + - 3) [Schedule] -> List[Schedule] """ - sch = Schedule(mod) # Make sure the schedule is traced - result = self.sch_fn(sch) # Call the schedule function - if result is None: # Case 1. No output - return [sch] - if isinstance(result, Schedule): # Case 2. Single output - return [result] - if isinstance(result, (list, tuple, Array)): # Case 3. Multiple outputs - for ret in result: # enumerate the outputs - if not isinstance(ret, Schedule): - raise TypeError( - "Wrong type of element in the list, expected Schedule got " - + f"'{type(ret)}': {ret}" - ) - return result - raise TypeError(f"Unexpected return type {type(result)}: {result}") + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorScheduleFn, # type: ignore # pylint: disable=no-member + sch_fn, + ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index eb999de49585..74c29b4de0dd 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,7 +18,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Union from tvm._ffi import register_object from tvm.ir import IRModule @@ -35,6 +35,12 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" + ScheduleFnType = Union[ + Callable[[Schedule], None], # No output + Callable[[Schedule], Schedule], # Single output + Callable[[Schedule], List[Schedule]], # Multiple outputs + ] + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. @@ -63,6 +69,9 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member +ScheduleFnType = SpaceGenerator.ScheduleFnType + + @register_object("meta_schedule.PySpaceGenerator") class _PySpaceGenerator(SpaceGenerator): """ diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index d39ad1738ec8..17acad8d4a57 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -36,7 +36,7 @@ from .runner import RunnerResult from .schedule_rule import ScheduleRule from .search_strategy import MeasureCandidate, SearchStrategy - from .space_generator import SCH_FN_TYPE, ScheduleFn, SpaceGenerator + from .space_generator import ScheduleFn, ScheduleFnType, SpaceGenerator from .tune import TuneConfig @@ -55,7 +55,7 @@ class TuneContext(Object): The workload to be optimized. target : Optional[Target] = None The target to be optimized for. - space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None + space_generator : Union[None, ScheduleFnType, SpaceGenerator] = None The design space generator. search_strategy : Union[None, TuneConfig, SearchStrategy] = None The search strategy. @@ -108,7 +108,7 @@ def __init__( mod: Optional[IRModule] = None, *, target: Optional[Target] = None, - space_generator: Union[None, "SCH_FN_TYPE", "ScheduleFn", "SpaceGenerator"] = None, + space_generator: Union[None, "ScheduleFnType", "ScheduleFn", "SpaceGenerator"] = None, search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None, sch_rules: Union[None, str, List["ScheduleRule"]] = None, postprocs: Union[None, str, List["Postproc"]] = None, diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 51dea2c2fe90..eab084f8978f 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -112,10 +112,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode { this->logging_func = context->logging_func; } - Array GenerateDesignSpace(const IRModule& mod_) final { + Array GenerateDesignSpace(const IRModule& mod) final { using ScheduleAndUnvisitedBlocks = std::pair>; tir::Schedule sch = tir::Schedule::Traced( - /*mod=*/mod_, + /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), /*debug_mode=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc new file mode 100644 index 000000000000..70559fbcf1fb --- /dev/null +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The union of design space generators. */ +class ScheduleFnNode : public SpaceGeneratorNode { + public: + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The schedule function. */ + runtime::PackedFunc schedule_fn_; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `schedule_fn_` is not visited. + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->rand_state_ = ForkSeed(&context->rand_state); + } + + Array GenerateDesignSpace(const IRModule& mod) final { + tir::Schedule sch = tir::Schedule::Traced( + /*mod=*/mod, + /*rand_state=*/ForkSeed(&this->rand_state_), + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + runtime::TVMRetValue rv; + rv = this->schedule_fn_(sch); + if (rv.type_code() == kTVMNullptr) { + return {sch}; + } + ObjectRef obj = rv; + if (const auto* sch = obj.as()) { + return {GetRef(sch)}; + } + if (const auto* arr = obj.as()) { + Array result; + result.reserve(arr->size()); + for (const ObjectRef& obj : *arr) { + if (const auto* sch = obj.as()) { + result.push_back(GetRef(sch)); + } else { + LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " + "List[Schedule], but got: " + << obj->GetTypeKey(); + } + } + return result; + } + LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " + "List[Schedule], but got: " + << obj->GetTypeKey(); + throw; + } + + static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); +}; + +SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn) { + ObjectPtr n = make_object(); + n->schedule_fn_ = std::move(schedule_fn); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(ScheduleFnNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") + .set_body_typed(SpaceGenerator::ScheduleFn); + +} // namespace meta_schedule +} // namespace tvm