From 7e66def609c9e389589564959f5e9f34f82740bf Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 22 Dec 2020 00:09:03 +0000 Subject: [PATCH 1/3] [AutoScheduler] Support string processing to records --- include/tvm/auto_scheduler/measure_record.h | 5 ++- python/tvm/auto_scheduler/measure_record.py | 40 +++++++++++++++++++ src/auto_scheduler/measure_record.cc | 24 +++++++++-- .../unittest/test_auto_scheduler_measure.py | 15 +++++-- 4 files changed, 76 insertions(+), 8 deletions(-) diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index fa8fe2b1b455..7cfbe8ba7ccd 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,6 +34,8 @@ namespace tvm { namespace auto_scheduler { +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) + /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { public: @@ -118,7 +120,8 @@ class RecordReader : public ObjectRef { * \param results The MeasureResults to be written. */ void WriteMeasureRecords(std::ostream* os, const Array& inputs, - const Array& results); + const Array& results, + const std::string log_version = AUTO_SCHEDULER_LOG_VERSION); /*! * \brief Read one measure record from a string. diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index d6fea5c48598..b5b720df585c 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -98,6 +98,46 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) +def load_record_from_string(record): + """ + Load the measure record from string. + + Parameters + ---------- + record: str + A record string, including the serialized MeasureInput and MeasureResult. + + Returns + ------- + ret: Tuple[MeasureInput, MeasureResult, str] + A tuple of MeasureInput, MeasureResult, and the log version. 
+ """ + return _ffi_api.ReadMeasureRecord(record) + + +def dump_record_to_string(inp, res, log_version): + """ + Dump the measure record to a string. + + Parameters + ---------- + inp: MeasureInput + The measure input. + + res: MeasureResult + The measure result. + + log_version: str + The log version of the given record. + + Returns + ------- + ret: str + The dumped string. + """ + return _ffi_api.WriteMeasureRecords(inp, res, log_version) + + def load_records(filename): """ Load measurement records from a file. diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index aad0abe88308..bc22689af362 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -279,8 +279,6 @@ namespace auto_scheduler { TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) - RecordToFile::RecordToFile(String filename) { auto node = make_object(); node->filename = std::move(filename); @@ -288,13 +286,13 @@ RecordToFile::RecordToFile(String filename) { } void WriteMeasureRecords(std::ostream* os, const Array& inputs, - const Array& results) { + const Array& results, const std::string log_version) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { writer.BeginObject(false); writer.WriteObjectKeyValue("i", *inputs[i].operator->()); writer.WriteObjectKeyValue("r", *results[i].operator->()); - writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION); + writer.WriteObjectKeyValue("v", log_version); writer.EndObject(); *os << "\n"; } @@ -398,6 +396,24 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](Rec } }); +TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) { + auto inp = make_object(); + auto res = make_object(); + std::string log_version; + ReadMeasureRecord(str, inp.get(), res.get(), &log_version); + return 
Array{ObjectRef(inp), ObjectRef(res), String(log_version)}; +}); + +TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords") + .set_body_typed([](MeasureInput inp, MeasureResult res, String log_version) { + auto inps = Array({inp}); + auto ress = Array({res}); + std::ostringstream ss; + + WriteMeasureRecords(&ss, inps, ress, log_version); + return String(ss.str()); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords") .set_body_typed([](String filename, Array in, Array res) { std::ofstream ofs(filename, std::ofstream::app); diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 10bb0b4ee276..e492c27b7322 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -34,11 +34,20 @@ def record_common(dag, s): inp = auto_scheduler.measure.MeasureInput(task, s) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + # Test in-memory record processing. + record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res, "v0.4") + r_inp, r_res, r_log_ver = auto_scheduler.measure_record.load_record_from_string(record_str) + # Only check the workload_key for simplification. + assert inp.task.workload_key == r_inp.task.workload_key + assert str(res) == str(r_res) + assert "v0.4" == r_log_ver + + # Test file-based record processing. 
with tempfile.NamedTemporaryFile() as fp: auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 s1 = dag.infer_bound_from_state(s) @@ -180,7 +189,7 @@ def test_recover_measure_input(): auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 raw_inp = inputs[0] @@ -266,7 +275,7 @@ def test_measure_target_host(): auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 raw_inp = inputs[0] From a03e24776bf94cb705100f32ed1797591d33777b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 22 Dec 2020 00:30:18 +0000 Subject: [PATCH 2/3] doc --- include/tvm/auto_scheduler/measure_record.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index 7cfbe8ba7ccd..4d7952f74b40 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -118,6 +118,7 @@ class RecordReader : public ObjectRef { * \param os A pointer to a output stream. * \param inputs The MeasureInputs to be written. * \param results The MeasureResults to be written. + * \param log_version The log version for the given record. 
*/ void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results, From c80857408b3b8e44f86a6503748b4dc70e5f42e0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 23 Dec 2020 06:42:21 +0000 Subject: [PATCH 3/3] remove log --- python/tvm/auto_scheduler/measure_record.py | 11 ++++------- src/auto_scheduler/measure_record.cc | 7 +++---- tests/python/unittest/test_auto_scheduler_measure.py | 5 ++--- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index b5b720df585c..35e5e9b68a43 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -109,13 +109,13 @@ def load_record_from_string(record): Returns ------- - ret: Tuple[MeasureInput, MeasureResult, str] - A tuple of MeasureInput, MeasureResult, and the log version. + ret: Tuple[MeasureInput, MeasureResult] + A tuple of MeasureInput, MeasureResult. """ return _ffi_api.ReadMeasureRecord(record) -def dump_record_to_string(inp, res, log_version): +def dump_record_to_string(inp, res): """ Dump the measure record to a string. @@ -127,15 +127,12 @@ def dump_record_to_string(inp, res, log_version): res: MeasureResult The measure result. - log_version: str - The log version of the given record. - Returns ------- ret: str The dumped string. 
""" - return _ffi_api.WriteMeasureRecords(inp, res, log_version) + return _ffi_api.WriteMeasureRecords(inp, res) def load_records(filename): diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index bc22689af362..faf3fca4cfc4 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -401,16 +401,15 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const auto res = make_object(); std::string log_version; ReadMeasureRecord(str, inp.get(), res.get(), &log_version); - return Array{ObjectRef(inp), ObjectRef(res), String(log_version)}; + return Array{ObjectRef(inp), ObjectRef(res)}; }); TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords") - .set_body_typed([](MeasureInput inp, MeasureResult res, String log_version) { + .set_body_typed([](MeasureInput inp, MeasureResult res) { auto inps = Array({inp}); auto ress = Array({res}); std::ostringstream ss; - - WriteMeasureRecords(&ss, inps, ress, log_version); + WriteMeasureRecords(&ss, inps, ress); return String(ss.str()); }); diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index e492c27b7322..e9f1fa40c8b3 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -35,12 +35,11 @@ def record_common(dag, s): res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) # Test in-memory record processing. - record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res, "v0.4") - r_inp, r_res, r_log_ver = auto_scheduler.measure_record.load_record_from_string(record_str) + record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res) + r_inp, r_res = auto_scheduler.measure_record.load_record_from_string(record_str) # Only check the workload_key for simplification. 
assert inp.task.workload_key == r_inp.task.workload_key assert str(res) == str(r_res) - assert "v0.4" == r_log_ver # Test file-based record processing. with tempfile.NamedTemporaryFile() as fp: