diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index e8c01e84f289..841b6b953087 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -232,6 +232,35 @@ class MeasureCallback : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); }; +/*! \brief A wrapper for measure callback defined by python code + * This class will call functions defined in the python */ +class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! \brief Pointer to the callback funcion in python */ + PackedFunc callback_func; + + void Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) final; + static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to PythonBasedMeasureCallbackNode. + * \sa PythonBasedMeasureCallbackNode + */ +class PythonBasedMeasureCallback : public MeasureCallback { + public: + /*! + * \brief The constructor. + * \param callback_func The pointer to the callback function defined in python + */ + explicit PythonBasedMeasureCallback(PackedFunc callback_func); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback, + PythonBasedMeasureCallbackNode); +}; + // The base class of ProgramBuilders and ProgramRunners. /*! \brief ProgramBuilder that builds the programs */ diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7e4f14933819..38a420df9d91 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -70,6 +70,31 @@ class MeasureCallback(Object): """ The base class of measurement callback functions. """ +@tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback") +class PythonBasedMeasureCallback(MeasureCallback): + """Base class for measure callbacks implemented in python""" + + def __init__(self): + def callback_func(policy, inputs, results): + self.callback(policy, inputs, results) + + self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func) + + def callback(self, policy, inputs, results): + """The callback function. + + Parameters + ---------- + policy: auto_scheduler.search_policy.SearchPolicy + The search policy. + inputs : List[auto_scheduler.measure.MeasureInput] + The measurement inputs + results : List[auto_scheduler.measure.MeasureResult] + The measurement results + """ + raise NotImplementedError + + @tvm._ffi.register_object("auto_scheduler.MeasureInput") class MeasureInput(Object): """Store the input of a measurement. diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 5b7e886f073c..c3212f2b4478 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -27,6 +27,8 @@ #include +#include "search_policy/empty_policy.h" +#include "search_policy/sketch_policy.h" #include "utils.h" namespace tvm { @@ -36,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode); TVM_REGISTER_NODE_TYPE(BuildResultNode); TVM_REGISTER_NODE_TYPE(MeasureResultNode); TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode); TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode); TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); @@ -183,6 +186,25 @@ Array RPCRunnerNode::Run(const Array& inputs, return Array(); } +/********** MeasureCallback **********/ +PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) { + auto node = make_object(); + node->callback_func = std::move(callback_func); + data_ = std::move(node); +} + +void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) { + if (auto* sketch_policy = static_cast(policy.operator->())) { + callback_func(GetRef(sketch_policy), inputs, results); + } else if (auto* empty_policy = static_cast(policy.operator->())) { + callback_func(GetRef(empty_policy), inputs, results); + } else { + LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy"; + } +} + /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, @@ -360,6 +382,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult") return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); }); +TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback") + .set_body_typed([](PackedFunc callback_func) { + return PythonBasedMeasureCallback(callback_func); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer") .set_body_typed([](ProgramBuilder builder, ProgramRunner runner, Array callbacks, int verbose, int max_continuous_error) { diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 1bb74497898c..6d4fb6884ff9 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -30,6 +30,16 @@ import multiprocessing +class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): + """A simple Python-based callback for testing.""" + + def callback(self, policy, inputs, results): + assert isinstance(policy, auto_scheduler.search_policy.SearchPolicy) + for inp, res in zip(inputs, results): + assert isinstance(inp, auto_scheduler.MeasureInput) + assert isinstance(res, auto_scheduler.MeasureResult) + + def search_common( workload=matmul_auto_scheduler_test, target="llvm", @@ -68,7 +78,7 @@ def search_common( early_stopping=1, runner=runner, verbose=2, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()], ) task.tune(tuning_options=tuning_options, search_policy=search_policy) sch, args = task.apply_best(log_file)