Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Integration test for MetaSchedule"""
from typing import List
import tempfile
import numpy as np
import pytest
Expand All @@ -26,6 +27,7 @@
from tvm._ffi import register_func
from tvm.contrib import graph_executor
from tvm.ir.transform import PassContext
from tvm.meta_schedule.database import Workload, TuningRecord
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune_context import _normalize_mod
Expand Down Expand Up @@ -347,7 +349,7 @@ def test_extract_task_arm_conv2d_nchwc():
assert list(out_type.shape) == [1, 8, 130, 130, 4]


def test_meta_schedule_te2primfunc_argument_order():
def test_meta_schedule_te2primfunc_argument_order_and_lowering():
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off
@tvm.script.ir_module
Expand Down Expand Up @@ -416,8 +418,52 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B
# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument

def _create_database():
database = ms.database.create("memory")
def _create_verification_database():
@ms.derived_object
class VerificationDatabase(ms.database.PyDatabase):
def __init__(self):
super().__init__()
self.tuning_records_: List[TuningRecord] = []
self.workloads_: List[Workload] = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workloads_:
if tvm.ir.structural_equal(mod, workload.mod):
return True
# Note: The database has already put in all correct workloads
# This is where we can check if the workload is correct
raise ValueError(
"The workload searched for is not in given database!"
+ " Incorrect TIR was generated from TE subgraph."
)

def commit_workload(self, mod: IRModule) -> ms.database.Workload:
# No need to deduplicate workload because they are specified
workload = ms.database.Workload(mod)
self.workloads_.append(workload)
return workload

def commit_tuning_record(self, record: TuningRecord) -> None:
self.tuning_records_.append(record)

def get_all_tuning_records(self) -> List[TuningRecord]:
return self.tuning_records_

def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]:
return sorted(
list(
filter(
lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod),
self.tuning_records_,
)
),
key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9,
)[:top_k]

def __len__(self) -> int:
return len(self.tuning_records_)

database = VerificationDatabase()

def _commit(mod):
workload = database.commit_workload(mod)
Expand Down Expand Up @@ -464,7 +510,7 @@ def _create_relay_mod():
dev,
)

with target, _create_database(), PassContext(
with target, _create_verification_database(), PassContext( # pylint: disable=not-context-manager
opt_level=3,
config={
"relay.backend.use_meta_schedule": True,
Expand Down