Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
Meta Schedule builders that translate IRModule to runtime.Module,
and then export
"""
from .builder import Builder, BuilderInput, BuilderResult, PyBuilder
from .builder import Builder, BuilderInput, BuilderResult, PyBuilder, create
from .local_builder import LocalBuilder
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
"""Meta Schedule builders that translate IRModule to runtime.Module, and then export"""
from typing import Callable, Dict, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import NDArray, Object
Expand Down Expand Up @@ -164,3 +168,16 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
The results of building the given inputs.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["local"] = "local",
*args,
**kwargs,
) -> Builder:
"""Create a Builder."""
from . import LocalBuilder # pylint: disable=import-outside-toplevel

if kind == "local":
return LocalBuilder(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Builder: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
The tvm.meta_schedule.database package.
The database that stores serialized tuning records and workloads
"""
from .database import Database, PyDatabase, TuningRecord, Workload
from .database import Database, PyDatabase, TuningRecord, Workload, create
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
from .ordered_union_database import OrderedUnionDatabase
Expand Down
41 changes: 40 additions & 1 deletion python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
"""TuningRecord database"""
from typing import Any, Callable, List, Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.tir.schedule import Schedule, Trace
from typing_extensions import Literal # pylint: disable=wrong-import-order

from .. import _ffi_api
from ..arg_info import ArgInfo
Expand Down Expand Up @@ -483,3 +487,38 @@ def __len__(self) -> int:
The number of records in the database
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Union[
Literal[
"json",
"memory",
"union",
"ordered_union",
],
Callable[[Schedule], bool],
] = "json",
*args,
**kwargs,
) -> Database:
"""Create a Database."""
from . import ( # pylint: disable=import-outside-toplevel
JSONDatabase,
MemoryDatabase,
OrderedUnionDatabase,
ScheduleFnDatabase,
UnionDatabase,
)

if callable(kind):
return ScheduleFnDatabase(kind, *args, **kwargs) # type: ignore
if kind == "json":
return JSONDatabase(*args, **kwargs)
if kind == "memory":
return MemoryDatabase(*args, **kwargs) # type: ignore
if kind == "union":
return UnionDatabase(*args, **kwargs) # type: ignore
if kind == "ordered_union":
return OrderedUnionDatabase(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Database: {kind}")
31 changes: 25 additions & 6 deletions python/tvm/meta_schedule/database/json_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""The default database that uses a JSON File to store tuning records"""
import os.path as osp
from typing import Optional

from tvm._ffi import register_object

from .. import _ffi_api
Expand All @@ -38,21 +41,37 @@ class JSONDatabase(Database):

def __init__(
self,
path_workload: str,
path_tuning_record: str,
path_workload: Optional[str] = None,
path_tuning_record: Optional[str] = None,
*,
work_dir: Optional[str] = None,
allow_missing: bool = True,
) -> None:
"""Constructor.

Parameters
----------
path_workload : str
The path to the workload table.
path_tuning_record : str
The path to the tuning record table.
path_workload : Optional[str] = None
The path to the workload table. If not specified,
will be generated from `work_dir` as `$work_dir/database_workload.json`.
path_tuning_record : Optional[str] = None
The path to the tuning record table. If not specified,
will be generated from `work_dir` as `$work_dir/database_tuning_record.json`.
work_dir : Optional[str] = None
The work directory, if specified, will be used to generate `path_tuning_record`
and `path_workload`.
allow_missing : bool
Whether to create new file when the given path is not found.
"""
if work_dir is not None:
if path_workload is None:
path_workload = osp.join(work_dir, "database_workload.json")
if path_tuning_record is None:
path_tuning_record = osp.join(work_dir, "database_tuning_record.json")
if path_workload is None:
raise ValueError("`path_workload` is not specified.")
if path_tuning_record is None:
raise ValueError("`path_tuning_record` is not specified.")
self.__init_handle_by_constructor__(
_ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member
path_workload,
Expand Down
12 changes: 10 additions & 2 deletions python/tvm/meta_schedule/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
Meta Schedule runners that runs an artifact either locally or through the RPC interface
"""
from .config import EvaluatorConfig, RPCConfig
from .rpc_runner import RPCRunner
from .local_runner import LocalRunner, LocalRunnerFuture
from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
from .rpc_runner import RPCRunner
from .runner import (
PyRunner,
PyRunnerFuture,
Runner,
RunnerFuture,
RunnerInput,
RunnerResult,
create,
)
22 changes: 21 additions & 1 deletion python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Runners"""
from typing import Callable, Optional, List
from typing import Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.runtime import Object
Expand Down Expand Up @@ -223,3 +228,18 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
The runner futures.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["local", "rpc"] = "local",
*args,
**kwargs,
) -> Runner:
"""Create a Runner."""
from . import LocalRunner, RPCRunner # pylint: disable=import-outside-toplevel

if kind == "local":
return LocalRunner(*args, **kwargs) # type: ignore
elif kind == "rpc":
return RPCRunner(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Runner: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
from .evolutionary_search import EvolutionarySearch
from .replay_func import ReplayFunc
from .replay_trace import ReplayTrace
from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy, create
29 changes: 29 additions & 0 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
"""
from typing import TYPE_CHECKING, Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir.schedule import Schedule
Expand Down Expand Up @@ -245,3 +249,28 @@ def notify_runner_results(
The profiling results from the runner.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal[
"evolutionary",
"replay_trace",
"replay_func",
] = "evolutionary",
*args,
**kwargs,
) -> SearchStrategy:
"""Create a search strategy."""
from . import ( # pylint: disable=import-outside-toplevel
EvolutionarySearch,
ReplayFunc,
ReplayTrace,
)

if kind == "evolutionary":
return EvolutionarySearch(*args, **kwargs)
if kind == "replay_trace":
return ReplayTrace(*args, **kwargs)
if kind == "replay_func":
return ReplayFunc(*args, **kwargs)
raise ValueError(f"Unknown SearchStrategy: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/space_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
"""
from .post_order_apply import PostOrderApply
from .schedule_fn import ScheduleFn
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator, create
from .space_generator_union import SpaceGeneratorUnion
28 changes: 28 additions & 0 deletions python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
"""
from typing import TYPE_CHECKING, Callable, List, Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
Expand Down Expand Up @@ -132,3 +136,27 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]:
The generated design spaces, i.e., schedules.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Union[
Literal["post_order_apply", "union"],
ScheduleFnType,
] = "post_order_apply",
*args,
**kwargs,
) -> SpaceGenerator:
"""Create a design space generator."""
from . import ( # pylint: disable=import-outside-toplevel
PostOrderApply,
ScheduleFn,
SpaceGeneratorUnion,
)

if callable(kind):
return ScheduleFn(kind, *args, **kwargs) # type: ignore
if kind == "post_order_apply":
return PostOrderApply(*args, **kwargs)
if kind == "union":
return SpaceGeneratorUnion(*args, **kwargs)
raise ValueError(f"Unknown SpaceGenerator: {kind}")
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/task_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
for measure candidates generation and measurement, then save
records to the database.
"""
from .task_scheduler import TaskScheduler, PyTaskScheduler
from .round_robin import RoundRobin
from .gradient_based import GradientBased
from .round_robin import RoundRobin
from .task_scheduler import PyTaskScheduler, TaskScheduler, create
20 changes: 20 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import logging
from typing import Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.runtime import Object

Expand Down Expand Up @@ -255,3 +260,18 @@ def touch_task(self, task_id: int) -> None:
"""
# Using self._outer to replace the self pointer
_ffi_api.TaskSchedulerTouchTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["round-robin", "gradient"] = "gradient",
*args,
**kwargs,
) -> "TaskScheduler":
"""Create a task scheduler."""
from . import GradientBased, RoundRobin # pylint: disable=import-outside-toplevel

if kind == "round-robin":
return RoundRobin(*args, **kwargs)
if kind == "gradient":
return GradientBased(*args, **kwargs)
raise ValueError(f"Unknown TaskScheduler name: {kind}")
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _get_network(
"float32": torch.float32, # pylint: disable=no-member
}[dtype]
)
scripted_model = torch.jit.trace(model, input_data).eval()
scripted_model = torch.jit.trace(model, input_data).eval() # type: ignore
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_network(
input_dtype = "int64"
a = torch.randint(10000, input_shape) # pylint: disable=no-member
model.eval()
scripted_model = torch.jit.trace(model, [a], strict=False)
scripted_model = torch.jit.trace(model, [a], strict=False) # type: ignore
input_name = "input_ids"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Expand Down