-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[microTVM] Add support for AutoTVM #8715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a9a67ef
14678ab
98f3468
687d4d5
a091433
f3fb2c0
0ffb0c3
cffa8f9
26e78ca
3d97f6a
142bc0e
8678346
ef0d331
b739908
cab83a9
1d09e96
9d79ef7
367525c
2da7b1b
c4efe81
b340703
adcfe4b
226082d
1780f91
8e39635
0ea7bd7
ac61a56
9e32c8e
8de0606
7853a4f
891c91d
9cbfae1
57dd7c1
c3f287f
82a32ca
90317d4
5d85120
990e9c9
7f19d44
6df9658
76c8396
2b09fac
1afb549
fd22020
80fb240
6791644
4e631f1
97c2cc4
387103e
099a493
42cf68f
93f46a2
f798111
ca773a4
530e6ae
49951e1
b497a80
fe7a03a
f4155cd
a8539dc
8503086
12aa292
31372fd
69bf13f
72ae64f
539e1bc
818f822
9ccbbb2
b22dd00
38ba16b
86768ff
274e482
0205908
6889ea1
c14d7b4
b0b44f6
ff49394
38509fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,14 +17,17 @@ | |
|
|
||
| """Defines a top-level glue class that operates the Transport and Flasher classes.""" | ||
|
|
||
| import json | ||
| import logging | ||
| import sys | ||
|
|
||
| from ..error import register_error | ||
| from .._ffi import get_global_func | ||
| from .._ffi import get_global_func, register_func | ||
| from ..contrib import graph_executor | ||
| from ..contrib import utils | ||
| from ..contrib.debugger import debug_executor | ||
| from ..rpc import RPCSession | ||
| from . import project | ||
| from .transport import IoTimeoutError | ||
| from .transport import TransportLogger | ||
|
|
||
|
|
@@ -234,3 +237,71 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None): | |
| graph_json_str, | ||
| dump_root=dump_root, | ||
| ) | ||
|
|
||
|
|
||
| RPC_SESSION = None | ||
|
|
||
|
|
||
| @register_func("tvm.micro.compile_and_create_micro_session") | ||
| def compile_and_create_micro_session( | ||
| mod_src_bytes: bytes, | ||
| template_project_dir: str, | ||
| project_options: dict = None, | ||
| ): | ||
| """Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| mod_src_bytes : bytes | ||
| The content of a tarfile which contains the TVM-generated sources which together form the | ||
| SystemLib. This tar is expected to be created by export_library. The tar will be extracted | ||
| into a directory and the sources compiled into a MicroLibrary using the Compiler. | ||
|
|
||
| template_project_dir: str | ||
| The path to a template microTVM Project API project which is used to generate the embedded | ||
| project that is built and flashed onto the target device. | ||
|
|
||
| project_options: dict | ||
| Options for the microTVM API Server contained in template_project_dir. | ||
| """ | ||
| global RPC_SESSION | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe you can avoid the global here by returning an object whose destructor closes the session. Is there a reason this would not work?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm you mean
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep. I'm not sure clear on the lifetime of the RPC_SESSION object though. It looks like the destroyer for it is never called? |
||
|
|
||
| temp_dir = utils.tempdir() | ||
| # Keep temp directory for generate project | ||
| temp_dir.set_keep_for_debug(True) | ||
mehrdadh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model_library_format_path = temp_dir / "model.tar.gz" | ||
| with open(model_library_format_path, "wb") as mlf_f: | ||
| mlf_f.write(mod_src_bytes) | ||
|
|
||
| try: | ||
| template_project = project.TemplateProject.from_directory(template_project_dir) | ||
| generated_project = template_project.generate_project_from_mlf( | ||
| model_library_format_path, | ||
| temp_dir / "generated-project", | ||
| options=json.loads(project_options), | ||
| ) | ||
| except Exception as exception: | ||
| logging.error("Project Generate Error: %s", str(exception)) | ||
| raise exception | ||
|
|
||
| generated_project.build() | ||
| generated_project.flash() | ||
| transport = generated_project.transport() | ||
|
|
||
| RPC_SESSION = Session(transport_context_manager=transport) | ||
| RPC_SESSION.__enter__() | ||
| return RPC_SESSION._rpc._sess | ||
|
|
||
|
|
||
| @register_func | ||
| def destroy_micro_session(): | ||
| """Destroy RPC session for microTVM autotune.""" | ||
| global RPC_SESSION | ||
|
|
||
| if RPC_SESSION is not None: | ||
| exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None) | ||
| RPC_SESSION = None | ||
| if (exc_type, exc_value, traceback) != (None, None, None): | ||
| exc = exc_type(exc_value) # See PEP 3109 | ||
| exc.__traceback__ = traceback | ||
| raise exc | ||
Uh oh!
There was an error while loading. Please reload this page.