Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import json
import logging
import os
import pathlib
import contextlib
import enum

Expand Down Expand Up @@ -115,23 +114,40 @@ class AutoTvmModuleLoader:

Parameters
----------
template_project_dir : Union[pathlib.Path, str]
template_project_dir : Union[os.PathLike, str]
project template path

project_options : dict
project generation option

project_dir: str
if use_existing is False: The path to save the generated microTVM Project.
if use_existing is True: The path to a generated microTVM Project for debugging.

use_existing: bool
skips the project generation and opens transport to the project at the project_dir address.
"""

def __init__(
self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None
self,
template_project_dir: Union[os.PathLike, str],
project_options: dict = None,
project_dir: Union[os.PathLike, str] = None,
use_existing: bool = False,
):
self._project_options = project_options
self._use_existing = use_existing

if isinstance(template_project_dir, (pathlib.Path, str)):
if isinstance(template_project_dir, (os.PathLike, str)):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

if isinstance(project_dir, (os.PathLike, str)):
self._project_dir = str(project_dir)
else:
self._project_dir = None

@contextlib.contextmanager
def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
Expand All @@ -147,6 +163,8 @@ def __call__(self, remote_kw, build_result):
build_result_bin,
self._template_project_dir,
json.dumps(self._project_options),
self._project_dir,
self._use_existing,
],
)
system_lib = remote.get_function("runtime.SystemLib")()
Expand Down
56 changes: 40 additions & 16 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import json
import logging
import sys

import os
import pathlib
import shutil
from typing import Union
from ..error import register_error
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
Expand Down Expand Up @@ -259,6 +262,8 @@ def compile_and_create_micro_session(
mod_src_bytes: bytes,
template_project_dir: str,
project_options: dict = None,
project_dir: Union[os.PathLike, str] = None,
use_existing: bool = False,
):
"""Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session.

Expand All @@ -275,25 +280,44 @@ def compile_and_create_micro_session(

project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""

temp_dir = utils.tempdir()
# Keep temp directory for generate project
temp_dir.set_keep_for_debug(True)
model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)
project_dir: Union[os.PathLike, str]
if use_existing is False: The path to save the generated microTVM Project.
if use_existing is True: The path to a generated microTVM Project for debugging.

try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
str(temp_dir / "generated-project"),
use_existing: bool
skips the project generation and opens transport to the project at the project_dir address.
"""

if use_existing:
project_dir = pathlib.Path(project_dir)
assert project_dir.is_dir(), f"{project_dir} does not exist."
build_dir = project_dir / "generated-project" / "build"
shutil.rmtree(build_dir)
generated_project = project.GeneratedProject.from_directory(
project_dir / "generated-project",
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception
else:
if project_dir:
temp_dir = utils.tempdir(custom_path=project_dir, keep_for_debug=True)
else:
temp_dir = utils.tempdir()

model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)

try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
str(temp_dir / "generated-project"),
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception

generated_project.build()
generated_project.flash()
Expand Down
39 changes: 26 additions & 13 deletions python/tvm/micro/testing/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pathlib import Path
from contextlib import ExitStack
import tempfile
import shutil

import tvm
from tvm.relay.op.contrib import cmsisnn
Expand All @@ -53,6 +54,7 @@ def tune_model(
"project_type": "host_driven",
**(project_options or {}),
}

module_loader = tvm.micro.AutoTvmModuleLoader(
template_project_dir=tvm.micro.get_microtvm_template_projects(platform),
project_options=project_options,
Expand Down Expand Up @@ -99,6 +101,7 @@ def create_aot_session(
timeout_override=None,
use_cmsis_nn=False,
project_options=None,
use_existing=False,
):
"""AOT-compiles and uploads a model to a microcontroller, and returns the RPC session"""

Expand All @@ -125,21 +128,31 @@ def create_aot_session(
parameter_size = len(tvm.runtime.save_param_dict(lowered.get_params()))
print(f"Model parameter size: {parameter_size}")

project = tvm.micro.generate_project(
str(tvm.micro.get_microtvm_template_projects(platform)),
lowered,
build_dir / "project",
{
f"{platform}_board": board,
"project_type": "host_driven",
# {} shouldn't be the default value for project options ({}
# is mutable), so we use this workaround
**(project_options or {}),
},
)
project_options = {
f"{platform}_board": board,
"project_type": "host_driven",
# {} shouldn't be the default value for project options ({}
# is mutable), so we use this workaround
**(project_options or {}),
}

if use_existing:
shutil.rmtree(build_dir / "project" / "build")
project = tvm.micro.GeneratedProject.from_directory(
build_dir / "project",
options=project_options,
)

else:
project = tvm.micro.generate_project(
str(tvm.micro.get_microtvm_template_projects(platform)),
lowered,
build_dir / "project",
project_options,
)

project.build()
project.flash()

return tvm.micro.Session(project.transport(), timeout_override=timeout_override)


Expand Down