diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index e9908cbfde14..9a1c9e8dc7f5 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Integration test for MetaSchedule""" +from typing import List import tempfile import numpy as np import pytest @@ -26,6 +27,7 @@ from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir.transform import PassContext +from tvm.meta_schedule.database import Workload, TuningRecord from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base from tvm.meta_schedule.tune_context import _normalize_mod @@ -347,7 +349,7 @@ def test_extract_task_arm_conv2d_nchwc(): assert list(out_type.shape) == [1, 8, 130, 130, 4] -def test_meta_schedule_te2primfunc_argument_order(): +def test_meta_schedule_te2primfunc_argument_order_and_lowering(): # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @tvm.script.ir_module @@ -416,8 +418,52 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument - def _create_database(): - database = ms.database.create("memory") + def _create_verification_database(): + @ms.derived_object + class VerificationDatabase(ms.database.PyDatabase): + def __init__(self): + super().__init__() + self.tuning_records_: List[TuningRecord] = [] + self.workloads_: List[Workload] = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workloads_: + if tvm.ir.structural_equal(mod, workload.mod): + return True + # Note: The database has already put in all correct workloads + # This is where we can check if the workload is correct + raise ValueError( + "The workload searched for is not in given database!" + + " Incorrect TIR was generated from TE subgraph." + ) + + def commit_workload(self, mod: IRModule) -> ms.database.Workload: + # No need to deduplicate workload because they are specified + workload = ms.database.Workload(mod) + self.workloads_.append(workload) + return workload + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.tuning_records_.append(record) + + def get_all_tuning_records(self) -> List[TuningRecord]: + return self.tuning_records_ + + def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]: + return sorted( + list( + filter( + lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod), + self.tuning_records_, + ) + ), + key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9, + )[:top_k] + + def __len__(self) -> int: + return len(self.tuning_records_) + + database = VerificationDatabase() def _commit(mod): workload = database.commit_workload(mod) @@ -464,7 +510,7 @@ def _create_relay_mod(): dev, ) - with target, _create_database(), PassContext( + with target, _create_verification_database(), PassContext( # pylint: disable=not-context-manager opt_level=3, config={ "relay.backend.use_meta_schedule": True,