Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,7 @@ def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Sessi
aot_module : AotModule
Runtime AOT module that can be used to execute.
"""
aot_mod = self.load_module(module_name, session)
return tvm.runtime.executor.AotModule(aot_mod["default"](session.device))
return session.get_aot_executor(module_name)


class HexagonLauncherAndroid(HexagonLauncherRPC):
Expand Down
176 changes: 171 additions & 5 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@

import tvm
from tvm import rpc as _rpc
import tvm.contrib.hexagon as hexagon
from tvm.relay.backend.executor_factory import (
ExecutorFactoryModule,
AOTExecutorFactoryModule,
GraphExecutorFactoryModule,
)


class Session:
Expand Down Expand Up @@ -101,6 +107,9 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
"""Load TVM module.

The session must be established (via __enter__) prior to
calling this function.

Parameters
----------
module : Union[str, pathlib.Path, tvm.runtime.Module]
Expand All @@ -115,16 +124,16 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
the file must already have been uploaded to the remote,
and be placed in the remote workspace.

session : Session

Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
TVMModule :
TVM module object.
"""

assert (
self.device is not None
), "Hexagon session must be started using __enter__ prior to use"

if isinstance(module, tvm.runtime.Module):
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = pathlib.Path(temp_dir)
Expand All @@ -136,3 +145,160 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):

assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module))
return self._rpc.get_function("tvm.hexagon.load_module")(str(module))

def get_graph_executor(
self,
graph_json: str,
module_name: Union[str, pathlib.Path],
):
"""Create a local GraphModule which consumes a remote libmod.

The session must be established (via __enter__) prior to
calling this function.

Parameters
----------

module_name : Union[str, pathlib.Path]

The remote module filename, following the same restrictions
as `load_module`.

graph_json : str

The string with the graph JSON.

Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.

"""

graph_mod = self.load_module(module_name)
return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)

def get_aot_executor(
self,
module_name: Union[str, pathlib.Path],
):
"""Create a local GraphModule which consumes a remote libmod.

The session must be established (via __enter__) prior to
calling this function.

Parameters
----------

module_name : Union[str, pathlib.Path]

The remote module filename, following the same restrictions
as `load_module`.

Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.

"""

aot_mod = self.load_module(module_name)
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))

def get_executor_from_factory(self, module: ExecutorFactoryModule):
"""Create a local GraphModule which consumes a remote libmod.

Parameters
----------

module : ExecutorFactoryModule

The module to upload to the remote
session and load.
"""
if isinstance(module, AOTExecutorFactoryModule):
return self._aot_executor_from_factory(module)
if isinstance(module, GraphExecutorFactoryModule):
return self._graph_executor_from_factory(module)

raise TypeError(f"Unsupported executor type: {type(module)}")

def _graph_executor_from_factory(
self,
module: Union[str, pathlib.Path, GraphExecutorFactoryModule],
):
"""Create a local GraphModule which consumes a remote libmod.

The session must be established (via __enter__) prior to
calling this function.

Parameters
----------

module : GraphExecutorFactoryModule

The graph executor module to upload to the remote and load.
This will typically be the output of `tvm.relay.build`,
when passing `executor=Executor("graph")`.

Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.

"""

graph_json = module.get_graph_json()
graph_mod = self.load_module(module.get_lib())

return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)

def _aot_executor_from_factory(
self,
module: Union[str, pathlib.Path, AOTExecutorFactoryModule],
):
"""Create a local GraphModule which consumes a remote libmod.

The session must be established (via __enter__) prior to
calling this function.

Parameters
----------

module : AOTExecutorFactoryModule

The graph executor module to upload to the remote and load.
This will typically be the output of `tvm.relay.build`,
when passing `executor=Executor("aot")`.

Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.

"""

hexagon_arch = set(
target.mcpu.replace("hexagon", "")
for target in module.target.values()
if "hexagon" in target.keys
)
assert hexagon_arch, "No hexagon target architecture found"
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
hexagon_arch = hexagon_arch.pop()

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = pathlib.Path(temp_dir)
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name

module.export_library(
str(binary_path),
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)

self.upload(binary_path, binary_name)

aot_mod = self.load_module(binary_name)
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))
13 changes: 13 additions & 0 deletions python/tvm/contrib/hexagon/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def create_aot_shared(so_name: Union[str, pathlib.Path], files, hexagon_arch: st
+ "HEXAGON_SDK_PATH in your environment."
)

# The AOT C codegen uses TVM runtime functions
# (e.g. TVMBackendAllocWorkspace) directly. On Hexagon these calls
# should be made using functions pointers provided as __TVM*
# variables in the provided context. This workaround allows the
# the TVM runtime symbols to be visible to the compiled shared
# library.
#
# This workaround can be removed when AOT codegen can be done with
# LLVM codegen.
workaround_link_flags = os.environ.get("HEXAGON_SHARED_LINK_FLAGS")
if workaround_link_flags:
options.extend(workaround_link_flags.split())

tvm_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) / ".." / ".." / ".." / ".."
compute_arch = f"compute{hexagon_arch}"
compile_options = [
Expand Down
Loading