diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 9186c9d039e0..19358552df10 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -25,7 +25,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief The builder's input. */ +/*! \brief The builder's input, containing an IRModule and the target. */ class BuilderInputNode : public runtime::Object { public: /*! \brief The IRModule to be built. */ @@ -57,7 +57,7 @@ class BuilderInput : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; -/*! \brief The builder's output. */ +/*! \brief The builder's output, containing the artifact path or error message if any. */ class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h new file mode 100644 index 000000000000..36d07024559d --- /dev/null +++ b/include/tvm/meta_schedule/runner.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_META_SCHEDULE_RUNNER_H_ +#define TVM_META_SCHEDULE_RUNNER_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ +class RunnerResultNode : public runtime::Object { + public: + /*! \brief The run time in seconds. If not None, error_msg should be None. */ + Optional> run_secs; + /*! \brief The error message, if any. If not None, run_secs should be None. */ + Optional error_msg; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("run_secs", &run_secs); + v->Visit("error_msg", &error_msg); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerResultNode + * \sa RunnerResultNode + */ +class RunnerResult : public runtime::ObjectRef { + public: + /*! + * \brief Constructor for RunnerResult. + * \param run_secs The run time in seconds. + * \param error_msg The error message, if any. + */ + TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_RUNNER_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h new file mode 100644 index 000000000000..941dae4336e1 --- /dev/null +++ b/include/tvm/meta_schedule/search_strategy.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +/*! \brief The schedule (with input shapes) to be measured. */ +class MeasureCandidateNode : public runtime::Object { + public: + /*! \brief The schedule for measurement. */ + tir::Schedule sch; + /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch", &sch); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); +}; + +/*! + * \brief Managed reference to MeasureCandidateNode. + * \sa MeasureCandidateNode + */ +class MeasureCandidate : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of MeasureCandidate. + * \param sch The schedule for measurement. + * \param args_info The argument information, e.g., (shape, dtype) for tensors. + */ + TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); +}; + +/*! + * \brief The search strategy for measure candidates generation. 
+ * \note The relationship between SearchStrategy and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ +class SearchStrategyNode : public runtime::Object { + public: + /*! \brief Virtual destructor */ + virtual ~SearchStrategyNode() = default; + + /*! + * \brief Initialize the search strategy with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! + * \brief Pre-tuning for the search strategy. + * \param design_spaces The design spaces for pre-tuning. + * \note Pre-tuning is supposed to be called before the tuning process and after the + * initialization. 
Because the search strategy is stateful, we can always call pretuning + * and reset the search strategy. + */ + virtual void PreTuning(const Array& design_spaces) = 0; + + /*! + * \brief Post-tuning for the search strategy. + * \note Post-tuning is supposed to be called after the tuning process and before we reset the + * search strategy with another pre-tuning. Post-tuning can be empty. + */ + virtual void PostTuning() = 0; + + /*! + * \brief Generate measure candidates from design spaces for measurement. + * \return The measure candidates generated, nullptr if finished. + */ + virtual Optional> GenerateMeasureCandidates() = 0; + + /*! + * \brief Update the search strategy with measurement results. + * \param results The measurement results from the runner. + */ + virtual void NotifyRunnerResults(const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); +}; + +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `PreTuning` method. + * \param design_spaces The design spaces for pre-tuning. + */ + using FPreTuning = runtime::TypedPackedFunc&)>; + /*! \brief The function type of `PostTuning` method. */ + using FPostTuning = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateMeasureCandidates` method. + * \return The measure candidates generated, nullptr if finished. + */ + using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; + /*! + * \brief The function type of `NotifyRunnerResults` method. + * \param results The measurement results from the runner. 
+ */ + using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. */ + FNotifyRunnerResults f_notify_runner_results; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->f_initialize_with_tune_context(context); + } + + void PreTuning(const Array& design_spaces) final { + this->f_pre_tuning(design_spaces); + } + + void PostTuning() final { this->f_post_tuning(); } + + Optional> GenerateMeasureCandidates() final { + return this->f_generate_measure_candidates(); + } + + void NotifyRunnerResults(const Array& results) final { + this->f_notify_runner_results(results); + } + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { + public: + /*! + * \brief Create a search strategy with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_pre_tuning The packed function of `PreTuning`. 
+ * \param f_post_tuning The packed function of `PostTuning`. + * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. + * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \return The search strategy created. + */ + TVM_DLL static SearchStrategy PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + + /*! + * \brief Constructor of replay trace search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for trace replaying. + */ + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 9528be2a85ad..3dc181e05d8a 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -28,7 +28,42 @@ namespace meta_schedule { // Forward declaration class TuneContext; -/*! \brief The abstract class for design space generation. */ +/*! + * \brief The abstract class for design space generation. 
+ * \note The relationship between SpaceGenerator and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ class SpaceGeneratorNode : public Object { public: /*! \brief Default destructor */ @@ -37,6 +72,7 @@ class SpaceGeneratorNode : public Object { /*! * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. 
 */ virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 6b733d074f6a..fcd2326050ed 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -102,6 +102,16 @@ class LinearCongruentialEngine { *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. } + /*! + * \brief Fork a new seed for another RNG from current random state. + * \return The forked seed. + */ + TRandState ForkSeed() { + // In order for reproducibility, we compute the new seed using RNG's random state and a + // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return ((*this)() * 32767) % 1999999973; + } + /*! * \brief Construct a random number generator with a random state pointer. * \param rand_state_ptr The random state pointer given in result_type*. diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index f8b2b026c83b..c22cc205bf35 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,5 +19,7 @@ from . import builder from . import database from . import space_generator +from . import search_strategy +from . import runner from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py new file mode 100644 index 000000000000..65d2ef04e04c --- /dev/null +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""meta_schedule.runner""" +from .runner import RunnerResult diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py new file mode 100644 index 000000000000..b756c6e6b011 --- /dev/null +++ b/python/tvm/meta_schedule/runner/runner.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Runners""" +from typing import List, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. 
import _ffi_api + + +@register_object("meta_schedule.RunnerResult") +class RunnerResult(Object): + """The runner's result + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + + run_secs: Optional[List[float]] + error_msg: Optional[str] + + def __init__( + self, + run_secs: Optional[List[float]], + error_msg: Optional[str], + ) -> None: + """Constructor + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerResult, # type: ignore # pylint: disable=no-member + run_secs, + error_msg, + ) diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py new file mode 100644 index 000000000000..40f21da0b2d1 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Search Strategy""" + +from .search_strategy import SearchStrategy, PySearchStrategy +from .replay_trace import ReplayTrace diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py new file mode 100644 index 000000000000..3afdff6de77e --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayTrace") +class ReplayTrace(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. 
+ """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__(self, num_trials_per_iter: int, num_trials_total: int): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.ReplayTrace, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py new file mode 100644 index 000000000000..72713155c41d --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from typing import List, Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..arg_info import ArgInfo +from ..runner import RunnerResult + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.MeasureCandidate") +class MeasureCandidate(Object): + """Measure candidate class. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. 
+ """ + + sch: Schedule + args_info: List[ArgInfo] + + def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + """Constructor. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.MeasureCandidate, # pylint: disable=no-member + sch, + args_info, + ) + + +@register_object("meta_schedule.SearchStrategy") +class SearchStrategy(Object): + """ + Search strategy is the class that generates the measure candidates. It has to be pre-tuned + before usage and post-tuned after usage. + """ + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the search strategy with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initialization. + """ + _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + self, tune_context + ) + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + """Pre-tuning for the search strategy. + + Parameters + ---------- + design_spaces : List[Schedule] + The design spaces for pre-tuning. + """ + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + + def post_tuning(self) -> None: + """Post-tuning for the search strategy.""" + _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + + def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: + """Generate measure candidates from design spaces for measurement. + + Returns + ------- + measure_candidates : Optional[List[MeasureCandidate]] + The measure candidates generated, None if finished. + """ + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + + def notify_runner_results(self, results: List[RunnerResult]) -> None: + """Update the search strategy with profiling results. 
+ + Parameters + ---------- + results : List[RunnerResult] + The profiling results from the runner. + """ + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + + +@register_object("meta_schedule.PySearchStrategy") +class PySearchStrategy(SearchStrategy): + """An abstract search strategy with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + def f_pre_tuning(design_spaces: List[Schedule]) -> None: + self.pre_tuning(design_spaces) + + def f_post_tuning() -> None: + self.post_tuning() + + def f_generate_measure_candidates() -> List[MeasureCandidate]: + return self.generate_measure_candidates() + + def f_notify_runner_results(results: List["RunnerResult"]) -> None: + self.notify_runner_results(results) + + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + f_initialize_with_tune_context, + f_pre_tuning, + f_post_tuning, + f_generate_measure_candidates, + f_notify_runner_results, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + raise NotImplementedError + + def post_tuning(self) -> None: + raise NotImplementedError + + def generate_measure_candidates(self) -> List[MeasureCandidate]: + raise NotImplementedError + + def notify_runner_results(self, results: List["RunnerResult"]) -> None: + raise NotImplementedError diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc new file mode 100644 index 000000000000..8f509bdd7b84 --- /dev/null +++ b/src/meta_schedule/runner/runner.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { + ObjectPtr n = make_object(); + n->run_secs = run_secs; + n->error_msg = error_msg; + this->data_ = n; +} + +TVM_REGISTER_NODE_TYPE(RunnerResultNode); + +TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") + .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { + return RunnerResult(run_secs, error_msg); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc new file mode 100644 index 000000000000..1c83aee8c0fd --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using trace and random decisions. */ +class ReplayTraceNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayTraceNode* self; + /*! \brief The design spaces. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayTraceNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The number of threads to use. -1 means using logical cpu number. */ + int num_threads_ = -1; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. 
*/ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this, design_spaces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + ICHECK_LT(st, ed); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + Array per_task_result(ed - st, MeasureCandidate{nullptr}); + auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, + int task_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace 
trace = design_spaces[design_space_index]->trace().value(); + tir::Trace new_trace = tir::Trace(trace->insts, {}); + tir::Schedule sch = tir::Schedule::Traced( // + self->mod_, // + /*rand_state=*/ForkSeed(&rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + }; + support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); + return per_task_result; +} + +inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayTraceNode); +TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc new file mode 100644 index 000000000000..fefe8dfce76e --- /dev/null +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { + ObjectPtr n = make_object(); + n->sch = sch; + n->args_info = args_info; + data_ = std::move(n); +} + +SearchStrategy SearchStrategy::PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = f_initialize_with_tune_context; + n->f_pre_tuning = f_pre_tuning; + n->f_post_tuning = f_post_tuning; + n->f_generate_measure_candidates = f_generate_measure_candidates; + n->f_notify_runner_results = f_notify_runner_results; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); +TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); +TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") + .set_body_typed([](tir::Schedule sch, Array args_info) -> MeasureCandidate { + return MeasureCandidate(sch, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") + .set_body_typed(SearchStrategy::PySearchStrategy); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") + .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); 
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") + .set_body_method(&SearchStrategyNode::PreTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") + .set_body_method(&SearchStrategyNode::PostTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") + .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") + .set_body_method(&SearchStrategyNode::NotifyRunnerResults); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 4c9e1e2c10a1..30294b8f91e1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,16 +23,22 @@ #include #include #include +#include +#include #include #include #include #include +#include #include #include +#include +#include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" +#include "../tir/schedule/primitive.h" namespace tvm { namespace meta_schedule { @@ -131,6 +137,76 @@ inline String JSONObj2Str(const ObjectRef& json_obj) { */ inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +/*! + * \brief Find the entry function of the given IRModule, i.e, functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * \param mod The IRModule to find the entry function. + * \return The entry function. 
+ */ +inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + return GetRef(func); + } + if (gv->name_hint == "main") { + main_func = func; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + return GetRef(main_func); + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 0) { + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " + << tir::AsTVMScript(mod); + } + if (num_prim_func > 1) { + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << tir::AsTVMScript(mod); + } + return GetRef(last_func); +} + +/*! + * \brief Fork a random state into another, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \return The forked random state + */ +inline support::LinearCongruentialEngine::TRandState ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state) { + return support::LinearCongruentialEngine(rand_state).ForkSeed(); +} + +/*! + * \brief Fork a random state into another ones, i.e. PRNG splitting. + * The given random state is also mutated. 
+ * \param rand_state The random state to be forked + * \param n The number of forks + * \return The forked random states + */ +inline std::vector ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state, int n) { + std::vector results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed()); + } + return results; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 07af73ebabb6..93eba520f9d1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -220,9 +220,7 @@ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState se } support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { - // In order for reproducibility, we computer the new seed using RNG's random state and a different - // set of parameters. Note that both 32767 and 1999999973 are prime numbers. - return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; + return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 8ad6bdf7d37f..8d8acd2693f4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,6 +26,14 @@ namespace tvm { namespace tir { /******** Schedule: Sampling ********/ +/*! + * \brief Sample a random integer from a given range. + * \param min_inclusive The minimum value of the range, inclusive. + * \param max_exclusive The maximum value of the range, exclusive. + * \return The random integer sampled in the given range. + */ +TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive); /*! 
* \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 8843ac613179..6ac6226118cd 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,6 +24,18 @@ namespace tvm { namespace tir { +int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive) { + CHECK(min_inclusive < max_exclusive) + << "ValueError: max_exclusive must be greater than min_inclusive."; + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + support::LinearCongruentialEngine rand_(rand_state); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py new file mode 100644 index 000000000000..6e90bddb84b4 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule SearchStrategy """ +# pylint: disable=missing-function-docstring +from typing import List + +import sys + +import pytest + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace + +from tvm.script import ty +from tvm.tir.schedule import Schedule, Trace + + +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (32, 32), "float32") + B = tir.match_buffer(b, (32, 32), "float32") + C = tir.match_buffer(c, (32, 32), "float32") + with tir.block([32, 32, tir.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + return str(trace_1) == str(trace_2) + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_replay_trace(): + num_trials_per_iter = 7 + num_trials_total = 20 + + 
(example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul()) + replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul()) + replay.initialize_with_tune_context(tune_context) + + num_trials_each_round: List[int] = [] + replay.pre_tuning([example_sch]) + while True: + candidates = replay.generate_measure_candidates() + if candidates is None: + break + num_trials_each_round.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + assert _is_trace_equal(candidate.sch, example_sch) + runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) + replay.notify_runner_results(runner_results) + replay.post_tuning() + assert num_trials_each_round == [7, 7, 6] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))