Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import logging
import os
import pathlib
import contextlib

from typing import Union
from .._ffi import libinfo
from .. import rpc as _rpc

Expand Down Expand Up @@ -67,21 +69,24 @@ class AutoTvmModuleLoader:

Parameters
----------
template_project_dir : str
template_project_dir : Union[pathlib.Path, str]
project template path

project_options : dict
project generation option
"""

def __init__(self, template_project_dir: str, project_options: dict = None):
def __init__(
self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None
):
self._project_options = project_options

if isinstance(template_project_dir, pathlib.Path):
if isinstance(template_project_dir, (pathlib.Path, str)):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

@contextlib.contextmanager
def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
build_result_bin = build_file.read()
Expand All @@ -100,10 +105,6 @@ def __call__(self, remote_kw, build_result):
)
system_lib = remote.get_function("runtime.SystemLib")()
yield remote, system_lib
try:
remote.get_function("tvm.micro.destroy_micro_session")()
except tvm.error.TVMError as exception:
_LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)


def autotvm_build_func():
Expand Down
31 changes: 9 additions & 22 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __enter__(self):
int(timeouts.session_start_retry_timeout_sec * 1e6),
int(timeouts.session_start_timeout_sec * 1e6),
int(timeouts.session_established_timeout_sec * 1e6),
self._shutdown,
)
)
self.device = self._rpc.cpu(0)
Expand All @@ -143,6 +144,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
"""Tear down this session and associated RPC session resources."""
self.transport.__exit__(exc_type, exc_value, exc_traceback)

def _shutdown(self):
self.__exit__(None, None, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps for a future cleanup: let's ensure we avoid double-calling __exit__ by checking some local var to determine whether we did already.



def lookup_remote_linked_param(mod, storage_id, template_tensor, device):
"""Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module.
Expand Down Expand Up @@ -239,9 +243,6 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
)


RPC_SESSION = None


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
Expand All @@ -264,7 +265,6 @@ def compile_and_create_micro_session(
project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
global RPC_SESSION

temp_dir = utils.tempdir()
# Keep temp directory for generate project
Expand All @@ -277,7 +277,7 @@ def compile_and_create_micro_session(
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
temp_dir / "generated-project",
str(temp_dir / "generated-project"),
options=json.loads(project_options),
)
except Exception as exception:
Expand All @@ -288,20 +288,7 @@ def compile_and_create_micro_session(
generated_project.flash()
transport = generated_project.transport()

RPC_SESSION = Session(transport_context_manager=transport)
RPC_SESSION.__enter__()
return RPC_SESSION._rpc._sess


@register_func
def destroy_micro_session():
"""Destroy RPC session for microTVM autotune."""
global RPC_SESSION

if RPC_SESSION is not None:
exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
RPC_SESSION = None
if (exc_type, exc_value, traceback) != (None, None, None):
exc = exc_type(exc_value) # See PEP 3109
exc.__traceback__ = traceback
raise exc
rpc_session = Session(transport_context_manager=transport)
# RPC exit is called by shutdown function.
rpc_session.__enter__()
return rpc_session._rpc._sess
2 changes: 1 addition & 1 deletion src/runtime/micro/micro_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue*
throw std::runtime_error(ss.str());
}
std::unique_ptr<RPCChannel> channel(micro_channel);
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "");
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]);
auto sess = CreateClientSession(ep);
*rv = CreateRPCSessionModule(sess);
});
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,13 @@ void RPCEndpoint::Init() {
* the key to modify their behavior.
*/
std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel,
std::string name, std::string remote_key) {
std::string name, std::string remote_key,
TypedPackedFunc<void()> fshutdown) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen this is needed when a Python-side session constructor may keep a reference to the underlying session. Does it make sense to you?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TQ requests change to fcleanup

std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
endpt->channel_ = std::move(channel);
endpt->name_ = std::move(name);
endpt->remote_key_ = std::move(remote_key);
endpt->fshutdown_ = fshutdown;
endpt->Init();
return endpt;
}
Expand Down Expand Up @@ -734,6 +736,7 @@ void RPCEndpoint::ServerLoop() {
(*f)();
}
channel_.reset(nullptr);
if (fshutdown_ != nullptr) fshutdown_();
}

int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) {
Expand Down
6 changes: 5 additions & 1 deletion src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ class RPCEndpoint {
* \param channel The communication channel.
* \param name The local name of the session, used for debug
* \param remote_key The remote key of the session
* \param fshutdown The shutdown Packed function
* if remote_key equals "%toinit", we need to re-intialize
* it by event handler.
*/
static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name,
std::string remote_key);
std::string remote_key,
TypedPackedFunc<void()> fshutdown = nullptr);

private:
class EventHandler;
Expand All @@ -190,6 +192,8 @@ class RPCEndpoint {
std::string name_;
// The remote key
std::string remote_key_;
// The shutdown Packed Function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: say something more like: "invoked when the RPC session is terminated"

TypedPackedFunc<void()> fshutdown_;
};

/*!
Expand Down
4 changes: 3 additions & 1 deletion tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -457,6 +457,8 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=target, params=params)
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_autotune():
inputs = {"data": input_data}

target = tvm.target.target.micro("host")
template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
template_project_dir = pathlib.Path(tvm.micro.get_standalone_crt_dir()) / "template" / "host"

pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True})
with pass_context:
Expand All @@ -265,7 +265,7 @@ def test_autotune():
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -286,6 +286,8 @@ def test_autotune():
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=TARGET, params=params)
Expand Down
8 changes: 4 additions & 4 deletions tutorials/micro/micro_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -146,7 +146,7 @@
# do_fork=False,
# build_func=tvm.micro.autotvm_build_func,
# )
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

# measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -162,7 +162,7 @@
n_trial=num_trials,
measure_option=measure_option,
callbacks=[
tvm.autotvm.callback.log_to_file("microtvm_autotune.log"),
tvm.autotvm.callback.log_to_file("microtvm_autotune.log.txt"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the double extension here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be consistent with cleanup step in doc stage:

rm -rf /tmp/$$.log.txt

tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
],
si_prefix="M",
Expand Down Expand Up @@ -214,7 +214,7 @@
##########################
# Once autotuning completes, you can time execution of the entire program using the Debug Runtime:

with tvm.autotvm.apply_history_best("microtvm_autotune.log"):
with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"):
with pass_context:
lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)

Expand Down