From 29278f9facd1609d7addd3282888967b4b1325ee Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 21 Dec 2020 21:10:55 +0000 Subject: [PATCH 1/9] add --- include/tvm/auto_scheduler/measure.h | 27 +++++++++++++++++++ python/tvm/auto_scheduler/measure.py | 24 +++++++++++++++++ src/auto_scheduler/measure.cc | 6 +++++ .../relay/test_auto_scheduler_tuning.py | 7 ++++- 4 files changed, 63 insertions(+), 1 deletion(-) diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index e8c01e84f289..213ea6ccf158 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -232,6 +232,33 @@ class MeasureCallback : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); }; +/*! \brief A wrapper for measure callback defined by python code + * This class will call functions defined in the python */ +class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! \brief Pointer to the callback funcion in python */ + PackedFunc callback_func; + + static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to PythonBasedMeasureCallbackNode. + * \sa PythonBasedMeasureCallbackNode + */ +class PythonBasedMeasureCallback : public MeasureCallback { + public: + /*! + * \brief The constructor. + * \param callback The pointer to the callback function defined in python + */ + PythonBasedMeasureCallback(PackedFunc callback); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback, + PythonBasedMeasureCallbackNode); +}; + // The base class of ProgramBuilders and ProgramRunners. /*! \brief ProgramBuilder that builds the programs */ diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7e4f14933819..6bb6cccc344a 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -70,6 +70,30 @@ class MeasureCallback(Object): """ The base class of measurement callback functions. """ +@tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback") +class PythonBasedMeasureCallback(MeasureCallback): + """Base class for measure callbacks implemented in python""" + + def __init__(self): + def callback_func(policy, inputs, results): + self.callback(policy, inputs, results) + + self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func) + + def callback(self, policy, inputs, results): + """Update the cost model according to new measurement results (training data). + + Parameters + ---------- + policy: SearchPolicy + The search policy. + inputs : List[auto_scheduler.measure.MeasureInput] + The measurement inputs + results : List[auto_scheduler.measure.MeasureResult] + The measurement results + """ + raise NotImplementedError + @tvm._ffi.register_object("auto_scheduler.MeasureInput") class MeasureInput(Object): """Store the input of a measurement. diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 5b7e886f073c..1f333ec5280e 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -36,6 +36,7 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode); TVM_REGISTER_NODE_TYPE(BuildResultNode); TVM_REGISTER_NODE_TYPE(MeasureResultNode); TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode); TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode); TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); @@ -360,6 +361,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult") return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); }); +TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback") + .set_body_typed([](PackedFunc callback_func) { + return PythonBasedMeasureCallback(callback_func); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer") .set_body_typed([](ProgramBuilder builder, ProgramRunner runner, Array callbacks, int verbose, int max_continuous_error) { diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index d42373c86626..226e73c80b87 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -22,6 +22,11 @@ from test_auto_scheduler_task_extraction import get_network +class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): + """A simple Python-based callback for testing.""" + def callback(self, policy, inputs, results): + for inp, res in zip(inputs, results): + print(inp, res) def tune_network(network, target): # Extract tasks @@ -41,7 +46,7 @@ def tune_network(network, target): early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()], ) tuner.tune(tune_option, search_policy="sketch.random") del measure_ctx From a9faf51f2d499a74e5af684d232da2510345e6c1 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 21 Dec 2020 21:42:47 +0000 Subject: [PATCH 2/9] make it work --- include/tvm/auto_scheduler/measure.h | 6 ++++-- python/tvm/auto_scheduler/measure.py | 8 +++----- src/auto_scheduler/measure.cc | 13 +++++++++++++ tests/python/relay/test_auto_scheduler_tuning.py | 2 +- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index 213ea6ccf158..841b6b953087 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -239,6 +239,8 @@ class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { /*! \brief Pointer to the callback funcion in python */ PackedFunc callback_func; + void Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) final; static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback"; TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode); }; @@ -251,9 +253,9 @@ class PythonBasedMeasureCallback : public MeasureCallback { public: /*! * \brief The constructor. - * \param callback The pointer to the callback function defined in python + * \param callback_func The pointer to the callback function defined in python */ - PythonBasedMeasureCallback(PackedFunc callback); + explicit PythonBasedMeasureCallback(PackedFunc callback_func); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback, PythonBasedMeasureCallbackNode); diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 6bb6cccc344a..46a19a07ea7e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -75,18 +75,16 @@ class PythonBasedMeasureCallback(MeasureCallback): """Base class for measure callbacks implemented in python""" def __init__(self): - def callback_func(policy, inputs, results): - self.callback(policy, inputs, results) + def callback_func(inputs, results): + self.callback(inputs, results) self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func) - def callback(self, policy, inputs, results): + def callback(self, inputs, results): """Update the cost model according to new measurement results (training data). Parameters ---------- - policy: SearchPolicy - The search policy. inputs : List[auto_scheduler.measure.MeasureInput] The measurement inputs results : List[auto_scheduler.measure.MeasureResult] diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 1f333ec5280e..aa203949899f 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -184,6 +184,19 @@ Array RPCRunnerNode::Run(const Array& inputs, return Array(); } +/********** MeasureCallback **********/ +PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) { + auto node = make_object(); + node->callback_func = std::move(callback_func); + data_ = std::move(node); +} + +void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) { + callback_func(inputs, results); +} + /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 226e73c80b87..9dbb8221b6fb 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -24,7 +24,7 @@ class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): """A simple Python-based callback for testing.""" - def callback(self, policy, inputs, results): + def callback(self, inputs, results): for inp, res in zip(inputs, results): print(inp, res) From 4b48af2aa5ac4b3cb6d443ebb61dc38a25204fce Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 21 Dec 2020 22:04:33 +0000 Subject: [PATCH 3/9] format --- python/tvm/auto_scheduler/measure.py | 1 + tests/python/relay/test_auto_scheduler_tuning.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 46a19a07ea7e..be3650ed886f 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -92,6 +92,7 @@ def callback(self, inputs, results): """ raise NotImplementedError + @tvm._ffi.register_object("auto_scheduler.MeasureInput") class MeasureInput(Object): """Store the input of a measurement. diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 9dbb8221b6fb..5207b867ba69 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -22,12 +22,15 @@ from test_auto_scheduler_task_extraction import get_network + class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): """A simple Python-based callback for testing.""" + def callback(self, inputs, results): for inp, res in zip(inputs, results): print(inp, res) + def tune_network(network, target): # Extract tasks mod, params = get_network(network) From 1d1ad744d43043251a5efc21dac32f75f5e53e53 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 23 Dec 2020 06:21:31 +0000 Subject: [PATCH 4/9] add poilcy --- python/tvm/auto_scheduler/measure.py | 8 +++++--- src/auto_scheduler/measure.cc | 10 +++++++++- tests/python/relay/test_auto_scheduler_tuning.py | 5 +++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index be3650ed886f..29e02d0b7999 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -75,16 +75,18 @@ class PythonBasedMeasureCallback(MeasureCallback): """Base class for measure callbacks implemented in python""" def __init__(self): - def callback_func(inputs, results): - self.callback(inputs, results) + def callback_func(policy, inputs, results): + self.callback(policy, inputs, results) self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func) - def callback(self, inputs, results): + def callback(self, policy, inputs, results): """Update the cost model according to new measurement results (training data). Parameters ---------- + policy: auto_scheduler.search_policy.SearchPolicy + The search policy. inputs : List[auto_scheduler.measure.MeasureInput] The measurement inputs results : List[auto_scheduler.measure.MeasureResult] diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index aa203949899f..c3212f2b4478 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -27,6 +27,8 @@ #include +#include "search_policy/empty_policy.h" +#include "search_policy/sketch_policy.h" #include "utils.h" namespace tvm { @@ -194,7 +196,13 @@ PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) { - callback_func(inputs, results); + if (auto* sketch_policy = static_cast(policy.operator->())) { + callback_func(GetRef(sketch_policy), inputs, results); + } else if (auto* empty_policy = static_cast(policy.operator->())) { + callback_func(GetRef(empty_policy), inputs, results); + } else { + LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy"; + } } /********** ProgramMeasurer **********/ diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 5207b867ba69..7b81d54be7c6 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -26,9 +26,10 @@ class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): """A simple Python-based callback for testing.""" - def callback(self, inputs, results): + def callback(self, policy, inputs, results): + assert isinstance(policy, auto_scheduler.search_policy.SketchPolicy) for inp, res in zip(inputs, results): - print(inp, res) + print(policy, inp, res) def tune_network(network, target): From d60b1dcc18390410a90658e65731855dd7d394aa Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 23 Dec 2020 06:22:31 +0000 Subject: [PATCH 5/9] comment --- python/tvm/auto_scheduler/measure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 29e02d0b7999..38a420df9d91 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -81,7 +81,7 @@ def callback_func(policy, inputs, results): self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func) def callback(self, policy, inputs, results): - """Update the cost model according to new measurement results (training data). + """The callback function. Parameters ---------- From a7c600eaad2e153e0dd989c54cd4afa1d3d6eecb Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 23 Dec 2020 17:43:53 +0000 Subject: [PATCH 6/9] move test --- tests/python/relay/test_auto_scheduler_tuning.py | 2 +- .../unittest/test_auto_scheduler_search_policy.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 7b81d54be7c6..9b6c6f2c263a 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -50,7 +50,7 @@ def tune_network(network, target): early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), - measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()], + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) tuner.tune(tune_option, search_policy="sketch.random") del measure_ctx diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 1bb74497898c..77f04152e493 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -29,6 +29,14 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread import multiprocessing +class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): + """A simple Python-based callback for testing.""" + + def callback(self, policy, inputs, results): + assert isinstance(policy, auto_scheduler.search_policy.SketchPolicy) + for inp, res in zip(inputs, results): + assert isinstance(inp, auto_scheduler.MeasureInput) + assert isinstance(res, auto_scheduler.MeasureResult) def search_common( workload=matmul_auto_scheduler_test, @@ -68,7 +76,7 @@ def search_common( early_stopping=1, runner=runner, verbose=2, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()], ) task.tune(tuning_options=tuning_options, search_policy=search_policy) sch, args = task.apply_best(log_file) From 6736d6df792b0740c62a702ab8646ad8614d05a1 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 23 Dec 2020 18:17:13 +0000 Subject: [PATCH 7/9] format --- tests/python/unittest/test_auto_scheduler_search_policy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 77f04152e493..97448bc1984a 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -29,6 +29,7 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread import multiprocessing + class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): """A simple Python-based callback for testing.""" @@ -38,6 +39,7 @@ def callback(self, policy, inputs, results): assert isinstance(inp, auto_scheduler.MeasureInput) assert isinstance(res, auto_scheduler.MeasureResult) + def search_common( workload=matmul_auto_scheduler_test, target="llvm", From 428e45b8034befbfc82ccd7e93c4fbfe2c9a2b9e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 24 Dec 2020 00:27:05 +0000 Subject: [PATCH 8/9] fix ci --- tests/python/unittest/test_auto_scheduler_search_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 97448bc1984a..6d4fb6884ff9 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -34,7 +34,7 @@ class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): """A simple Python-based callback for testing.""" def callback(self, policy, inputs, results): - assert isinstance(policy, auto_scheduler.search_policy.SketchPolicy) + assert isinstance(policy, auto_scheduler.search_policy.SearchPolicy) for inp, res in zip(inputs, results): assert isinstance(inp, auto_scheduler.MeasureInput) assert isinstance(res, auto_scheduler.MeasureResult) From 16908f67580b42b6d579a81d3f7ee67d18054dd5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 24 Dec 2020 00:08:57 -0800 Subject: [PATCH 9/9] Delete useless old code --- tests/python/relay/test_auto_scheduler_tuning.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 9b6c6f2c263a..d42373c86626 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -23,15 +23,6 @@ from test_auto_scheduler_task_extraction import get_network -class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback): - """A simple Python-based callback for testing.""" - - def callback(self, policy, inputs, results): - assert isinstance(policy, auto_scheduler.search_policy.SketchPolicy) - for inp, res in zip(inputs, results): - print(policy, inp, res) - - def tune_network(network, target): # Extract tasks mod, params = get_network(network)