From 0ec2053fa6b27e5798f436e09724d4ba554ebeb4 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Wed, 14 Dec 2022 00:49:47 +0000 Subject: [PATCH] [Metaschedule] Align get_top_k logic in MemoryDatabase and JSONDatabase --- src/meta_schedule/database/json_database.cc | 10 ++++- src/meta_schedule/database/memory_database.cc | 12 +++++- .../unittest/test_meta_schedule_database.py | 39 +++++++++++++++++++ 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index bd5183f0cf60..22d6ec849c5f 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -126,16 +126,22 @@ class JSONDatabaseNode : public DatabaseNode { } Array results; results.reserve(top_k); - int counter = 0; for (const TuningRecord& record : this->tuning_records_) { + if (!record->run_secs.defined() || record->run_secs.value().empty()) { + continue; + } if (record->workload.same_as(workload) || WorkloadEqual(GetModuleEquality())(record->workload, workload)) { results.push_back(record); - if (++counter == top_k) { + if (results.size() == static_cast(top_k)) { break; } } } + if (results.size() < static_cast(top_k)) { + LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " + "enough valid records in the database for this workload."; + } return results; } diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 24fba6dfa105..19178a35f456 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -61,8 +61,12 @@ class MemoryDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } Array GetTopK(const Workload& workload, int top_k) final { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } std::vector> results; - results.reserve(this->records.size()); + results.reserve(records.size()); for (const TuningRecord& record : records) { if (!record->run_secs.defined()) { continue; @@ -83,7 +87,7 @@ class MemoryDatabaseNode : public DatabaseNode { std::sort(results.begin(), results.end()); auto begin = results.begin(); auto end = results.end(); - if (static_cast(results.size()) > top_k) { + if (results.size() > static_cast(top_k)) { end = begin + top_k; } Array ret; @@ -92,6 +96,10 @@ class MemoryDatabaseNode : public DatabaseNode { ret.push_back(begin->second); ++begin; } + if (ret.size() < static_cast(top_k)) { + LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " + "enough valid records in the database for this workload."; + } return ret; } diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index 777c5589a141..4ec10b556c3b 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -18,6 +18,7 @@ """Test Meta Schedule Database""" import os.path as osp import tempfile +import pytest from typing import Callable, Optional, List import tvm @@ -536,5 +537,43 @@ def test_meta_schedule_pydatabase_current(): assert ms.database.Database.current() == db +def call_get_top_k(run_secs_list, database, k): + mod: IRModule = Matmul + workload = database.commit_workload(mod) + for run_secs in run_secs_list: + record = ms.database.TuningRecord( + _create_schedule(mod, _schedule_matmul).trace, + workload, + run_secs, + tvm.target.Target("llvm"), + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), + ) + database.commit_tuning_record(record) + return [[v.value for v in record.run_secs] for record in database.get_top_k(workload, k)] + + +@pytest.mark.parametrize( + "k,expected", + [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])], +) +def test_memory_database_get_top_k(k, expected): + run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]] + database = ms.database.MemoryDatabase() + result = call_get_top_k(run_secs_list, database, k) + assert result == expected + + +@pytest.mark.parametrize( + "k,expected", + [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])], +) +def test_json_database_get_top_k(k, expected): + run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]] + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + result = call_get_top_k(run_secs_list, database, k) + assert result == expected + + if __name__ == "__main__": tvm.testing.main()