Merged
9 changes: 8 additions & 1 deletion python/tvm/meta_schedule/integration.py
@@ -25,6 +25,7 @@
from tvm.target import Target
from tvm.tir import PrimFunc

from .database import Database
from . import _ffi_api


@@ -174,7 +175,13 @@ def __init__(self) -> None:

@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(MetaScheduleContext):
    pass
    """An integration context that allows application of the historically best record from the database"""

    database: Database
    """The database to be queried from"""

    def __init__(self, database) -> None:
        self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database)  # type: ignore # pylint: disable=no-member


def extract_task(
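For context, a minimal usage sketch of the new Python class (mirroring the unit test added later in this PR; `database`, `relay_mod`, and `prim_mod` are placeholder names, not objects defined in this diff):

from tvm.meta_schedule.integration import ApplyHistoryBest

# `database` is assumed to be some tvm.meta_schedule.database.Database that already holds tuning records.
env = ApplyHistoryBest(database)
# query() returns the best-tuned IRModule for the dispatched workload,
# or None when the database has no matching tuning record.
tuned_mod = env.query(task_name="my-task", mod=relay_mod, dispatched=[prim_mod])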
22 changes: 21 additions & 1 deletion src/meta_schedule/integration.cc
@@ -20,6 +20,8 @@
#include <tvm/relay/function.h>
#include <tvm/tir/function.h>

#include "./utils.h"

namespace tvm {
namespace meta_schedule {

@@ -112,7 +114,21 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) {

Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
                                                Optional<Array<IRModule>> dispatched) {
  throw;
  ICHECK(dispatched.defined());
  ICHECK_EQ(dispatched.value().size(), 1);
  ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
  IRModule prim_mod = dispatched.value()[0];
  ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
  // Unify func name to make sure it can be found in database
  prim_mod = UnifyFuncName(prim_mod);
  if (database->HasWorkload(prim_mod)) {
    Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
    if (records.size() == 1) {
Review thread on this line:

Contributor: Should we do ICHECK_EQ(records.size(), 1) here?

Member (Author): Good question. I don't think we need the check here, because HasWorkload only implies that the workload is in the workload registry; it does not imply that the database holds a valid tuning record for that workload. Therefore, I removed such a check here.

Contributor: Ah, that makes sense! Thanks for the answer!

      LOG(INFO) << "Applied history best for " << task_name << ".";
      return records[0]->workload->mod;
    }
  }
  return NullOpt;
}

/**************** FFI ****************/
@@ -146,6 +162,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery")
TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction {
  return TaskExtraction();
});
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
    .set_body_typed([](Database database) -> ApplyHistoryBest {
      return ApplyHistoryBest(database);
    });

} // namespace meta_schedule
} // namespace tvm
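To make the reasoning in the review thread above concrete, here is a small Python sketch (it reuses the DummyDatabase and MockModule defined in the test file later in this diff; this is illustrative, not code from the PR) showing why HasWorkload alone does not guarantee that a tuning record exists:

db = DummyDatabase()
workload = db.commit_workload(MockModule)    # registers the workload only
assert db.has_workload(MockModule)           # True: the workload is in the registry
assert len(db.get_top_k(workload, 1)) == 0   # ...but no tuning record was committed,
                                             # so Query() falls through and returns NullOpt (None in Python)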
16 changes: 16 additions & 0 deletions src/meta_schedule/utils.h
@@ -351,6 +351,22 @@ inline int GetTargetNumCores(const Target& target) {
  return num_cores;
}

/*!
 * \brief Unify the function name in workload to "main".
 * \param mod The workload.
 * \return The new workload with unified function name.
 * \note If the name is not unified, the workload may not be found in the database.
 */
inline IRModule UnifyFuncName(const IRModule& mod) {
  if (!mod->ContainGlobalVar("main") && mod->GetGlobalVars().size() == 1) {
    IRModule new_mod = IRModule(
        Map<GlobalVar, BaseFunc>({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}}));
    return new_mod;
  } else {
    return mod;
  }
}

} // namespace meta_schedule
} // namespace tvm
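For readers more comfortable with the Python API, the intent of UnifyFuncName is roughly the following (a hedged sketch under the assumption that the module holds exactly one global function not yet named "main"; this helper is not part of the PR):

import tvm

def unify_func_name(mod: tvm.IRModule) -> tvm.IRModule:
    # Rebuild the module so its single function is registered under "main",
    # matching the name the database lookup expects.
    gvs = mod.get_global_vars()
    if len(gvs) == 1 and all(gv.name_hint != "main" for gv in gvs):
        return tvm.IRModule({"main": mod[gvs[0]]})
    return mod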

58 changes: 58 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -22,10 +22,14 @@
import tvm
from tvm import meta_schedule as ms
from tvm.ir.module import IRModule
from tvm.tir import Schedule
from tvm.target import Target
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
from tvm.meta_schedule.integration import (
    ExtractedTask,
    MetaScheduleContext,
    TaskExtraction,
    ApplyHistoryBest,
)
from tvm.meta_schedule.testing import get_network
from tvm.script import tir as T
@@ -116,5 +120,59 @@ def test_meta_schedule_integration_extract_from_resnet():
    assert len(extracted_tasks) == 30


def test_meta_schedule_integration_apply_history_best():
    class DummyDatabase(PyDatabase):
        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        def has_workload(self, mod: IRModule) -> bool:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return True
            return False

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
            return list(
                filter(
                    lambda x: x.workload == workload,
                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
                )
            )[: int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, Target("llvm"), [])
    )
    mod = env.query(task_name="mock-task", mod=mod, dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))