diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index fa8fe2b1b455..4d7952f74b40 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,6 +34,8 @@ namespace tvm { namespace auto_scheduler { +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) + /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { public: @@ -116,9 +118,11 @@ class RecordReader : public ObjectRef { * \param os A pointer to a output stream. * \param inputs The MeasureInputs to be written. * \param results The MeasureResults to be written. + * \param log_version The log version for the given record. */ void WriteMeasureRecords(std::ostream* os, const Array& inputs, - const Array& results); + const Array& results, + const std::string log_version = AUTO_SCHEDULER_LOG_VERSION); /*! * \brief Read one measure record from a string. diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index d6fea5c48598..35e5e9b68a43 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -98,6 +98,43 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) +def load_record_from_string(record): + """ + Load the measure record from string. + + Parameters + ---------- + record: str + A record string, including the serialized MeausreInput and MeasureResult. + + Returns + ------- + ret: Tuple[MeasureInput, MeasureResult] + A tuple of MeasureInput, MeasureResult. + """ + return _ffi_api.ReadMeasureRecord(record) + + +def dump_record_to_string(inp, res): + """ + Dump the measure record to a string. + + Parameters + ---------- + inp: MeasureInput + The measure input. + + res: MeasureResult + The measure result. + + Returns + ------- + ret: str + The dumped string. + """ + return _ffi_api.WriteMeasureRecords(inp, res) + + def load_records(filename): """ Load measurement records from a file. diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index aad0abe88308..faf3fca4cfc4 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -279,8 +279,6 @@ namespace auto_scheduler { TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) - RecordToFile::RecordToFile(String filename) { auto node = make_object(); node->filename = std::move(filename); @@ -288,13 +286,13 @@ RecordToFile::RecordToFile(String filename) { } void WriteMeasureRecords(std::ostream* os, const Array& inputs, - const Array& results) { + const Array& results, const std::string log_version) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { writer.BeginObject(false); writer.WriteObjectKeyValue("i", *inputs[i].operator->()); writer.WriteObjectKeyValue("r", *results[i].operator->()); - writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION); + writer.WriteObjectKeyValue("v", log_version); writer.EndObject(); *os << "\n"; } @@ -398,6 +396,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](Rec } }); +TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) { + auto inp = make_object(); + auto res = make_object(); + std::string log_version; + ReadMeasureRecord(str, inp.get(), res.get(), &log_version); + return Array{ObjectRef(inp), ObjectRef(res)}; +}); + +TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords") + .set_body_typed([](MeasureInput inp, MeasureResult res) { + auto inps = Array({inp}); + auto ress = Array({res}); + std::ostringstream ss; + WriteMeasureRecords(&ss, inps, ress); + return String(ss.str()); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords") .set_body_typed([](String filename, Array in, Array res) { std::ofstream ofs(filename, std::ofstream::app); diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 10bb0b4ee276..e9f1fa40c8b3 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -34,11 +34,19 @@ def record_common(dag, s): inp = auto_scheduler.measure.MeasureInput(task, s) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + # Test in-memory record processing. + record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res) + r_inp, r_res = auto_scheduler.measure_record.load_record_from_string(record_str) + # Only check the workload_key for simplification. + assert inp.task.workload_key == r_inp.task.workload_key + assert str(res) == str(r_res) + + # Test file-based record processing. with tempfile.NamedTemporaryFile() as fp: auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 s1 = dag.infer_bound_from_state(s) @@ -180,7 +188,7 @@ def test_recover_measure_input(): auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 raw_inp = inputs[0] @@ -266,7 +274,7 @@ def test_measure_target_host(): auto_scheduler.save_records(fp.name, [inp], [res]) log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() + inputs, _ = log_reader.read_lines() assert len(inputs) == 1 raw_inp = inputs[0]