Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,22 @@ class JSONDatabaseNode : public DatabaseNode {
}
Array<TuningRecord> results;
results.reserve(top_k);
int counter = 0;
for (const TuningRecord& record : this->tuning_records_) {
if (!record->run_secs.defined() || record->run_secs.value().empty()) {
continue;
}
if (record->workload.same_as(workload) ||
WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
results.push_back(record);
if (++counter == top_k) {
if (results.size() == static_cast<size_t>(top_k)) {
break;
}
}
}
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
}
return results;
}

Expand Down
12 changes: 10 additions & 2 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ class MemoryDatabaseNode : public DatabaseNode {
void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative";
if (top_k == 0) {
return {};
}
std::vector<std::pair<double, TuningRecord>> results;
results.reserve(this->records.size());
results.reserve(records.size());
for (const TuningRecord& record : records) {
if (!record->run_secs.defined()) {
continue;
Expand All @@ -83,7 +87,7 @@ class MemoryDatabaseNode : public DatabaseNode {
std::sort(results.begin(), results.end());
auto begin = results.begin();
auto end = results.end();
if (static_cast<int>(results.size()) > top_k) {
if (results.size() > static_cast<size_t>(top_k)) {
end = begin + top_k;
}
Array<TuningRecord> ret;
Expand All @@ -92,6 +96,10 @@ class MemoryDatabaseNode : public DatabaseNode {
ret.push_back(begin->second);
++begin;
}
if (ret.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
}
return ret;
}

Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_meta_schedule_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Test Meta Schedule Database"""
import os.path as osp
import tempfile
import pytest
from typing import Callable, Optional, List

import tvm
Expand Down Expand Up @@ -536,5 +537,43 @@ def test_meta_schedule_pydatabase_current():
assert ms.database.Database.current() == db


def call_get_top_k(run_secs_list, database, k):
mod: IRModule = Matmul
workload = database.commit_workload(mod)
for run_secs in run_secs_list:
record = ms.database.TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
workload,
run_secs,
tvm.target.Target("llvm"),
ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
)
database.commit_tuning_record(record)
return [[v.value for v in record.run_secs] for record in database.get_top_k(workload, k)]


@pytest.mark.parametrize(
"k,expected",
[(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
)
def test_memory_database_get_top_k(k, expected):
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
database = ms.database.MemoryDatabase()
result = call_get_top_k(run_secs_list, database, k)
assert result == expected


@pytest.mark.parametrize(
"k,expected",
[(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
)
def test_json_database_get_top_k(k, expected):
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir)
result = call_get_top_k(run_secs_list, database, k)
assert result == expected


if __name__ == "__main__":
tvm.testing.main()