Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ TVM_DLL Pass LowerThreadAllreduce();
*/
TVM_DLL Pass InferFragment();

/*!
* \brief Annotation key marking nodes for which builtin lowering is disabled
*/
static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";

/*!
* \brief Lower builtin intrinsics.
* \return The pass.
Expand Down
96 changes: 62 additions & 34 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Codegen for Arm(R) Ethos(TM)-U"""
"""Codegen for Arm(R) Ethos(TM)-U NPU"""

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
Expand All @@ -24,24 +25,6 @@
from tvm.relay.backend.contrib.ethosu import util


@tvm._ffi.register_func("relay.ext.ethos-u")
def ethosu_compiler(external_function):
    """Compile an external relay function into an Ethos-U runtime module.

    This is the entry point registered as ``relay.ext.ethos-u``. The
    external function of NPU compatible operators is compiled to a
    command stream, which is then passed to the
    ``runtime.module.ethos-u.create`` global function to create the
    c-source runtime module that interfaces with the NPU driver.

    Parameters
    ----------
    external_function : tvm.ir.function.BaseFunc
        The external relay function. It must carry a ``global_symbol``
        attribute and have exactly one parameter.

    Returns
    -------
    The value returned by ``runtime.module.ethos-u.create``.
    """
    assert isinstance(external_function, tvm.ir.function.BaseFunc)
    symbol = external_function.attrs["global_symbol"]
    func_params = external_function.params
    # Only a single input is supported
    assert len(func_params) == 1
    in_bytes = util.calculate_size_bytes(func_params[0])
    out_bytes = util.calculate_size_bytes(external_function.body)
    command_stream, constants, scratch_bytes = _compile(external_function)
    create_module = tvm._ffi.get_global_func("runtime.module.ethos-u.create")
    return create_module(symbol, command_stream, constants, scratch_bytes, in_bytes, out_bytes)


@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
def constant_updater(expr, symbol): # pylint: disable=unused-argument
"""
Expand All @@ -52,25 +35,25 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument
return dict()


def _compile(ext_func):
@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func")
def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
"""
This is the main wrapper that accepts an external
relay function and runs all the passes to lower it down
to command stream
This is the hook for python-based lowering of relay function
that gets offloaded to the microNPU.

Parameters
----------
ext_func : tvm.relay.function.Function
The partitioned relay function
ext_func : relay.Function
This is the partitioned relay function

Returns
-------
cs : str
An hex string of the bytes of command stream
encoded_constants : str
An hex string of the bytes that includes concat'd
encoded weights, encoded biases and scales.
scratch_size : int
The size of the scratch buffer needed.
primfunc : tir.PrimFunc
This returns the scheduled PrimFunc
"""
assert len(ext_func.params) == 1
input_size = util.calculate_size_bytes(ext_func.params[0])
output_size = util.calculate_size_bytes(ext_func.body)
mod = tvm.IRModule()
mod["main"] = ext_func
mod = LegalizeEthosU()(mod)
Expand All @@ -80,5 +63,50 @@ def _compile(ext_func):
# that can perform scheduling based on user inputs such as
# scratch memory size.
tir_mod, params = lower_to_tir(mod["main"], copy_constants())
cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params)
return cmms, encoded_constants, scratch_size

for idx in params.keys():
params[idx] = tvm.nd.array(params[idx])

primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", params)
primfunc = primfunc.with_attr("ethos-u.input_size", input_size)
primfunc = primfunc.with_attr("ethos-u.output_size", output_size)
return primfunc


@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact:
    """
    Hook for the python-based lowering of a TIR PrimFunc, that has
    undergone unified optimization, to the Compilation Artifact
    destined for the microNPU.

    The PrimFunc is expected to carry the ``global_symbol``,
    ``ethos-u.constants``, ``ethos-u.input_size`` and
    ``ethos-u.output_size`` attributes; all four are read here.

    Parameters
    ----------
    primfunc : tir.PrimFunc
        TIR PrimFunc that has undergone unified optimizations

    Returns
    -------
    CompilationArtifact
        The structure that holds the binary artifacts
        for the microNPU
    """
    attrs = primfunc.attrs
    symbol = str(attrs["global_symbol"])
    constants = attrs["ethos-u.constants"]
    input_size = attrs["ethos-u.input_size"]
    output_size = attrs["ethos-u.output_size"]

    # The translator works on a module, so wrap the PrimFunc under its symbol.
    tir_mod = tvm.IRModule()
    tir_mod[symbol] = primfunc

    # Hand the translator a plain dict with integer keys and NumPy values.
    numpy_constants = {int(idx): constants[idx].numpy() for idx in constants.keys()}

    cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(
        tir_mod, numpy_constants
    )
    return util.CompilationArtifact(
        cmms, encoded_constants, scratch_size, input_size, output_size, symbol
    )
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.relay.expr_functor import ExprMutator
from tvm.driver.build_module import schedule_to_module

from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, AnnotateAllocates
from .scheduler import schedule


Expand Down Expand Up @@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = EncodeConstants(const_dict)(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = AnnotateAllocates()(mod)
return mod, const_dict


Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,34 @@ def _encode_constants(mod):
return new_func, new_const_dict

return _encode_constants


# This needs to be kept in sync with kDisableLowerTVMBuiltin in include/tvm/tir/transform.h
DISABLE_LOWER_BUILTIN = "disable_lower_builtin"


def AnnotateAllocates():
    """
    Create the pass that annotates all allocate nodes
    of the PrimFuncs of the microNPU so that they are
    not lowered to built-ins.

    Returns
    -------
    The PrimFunc pass created by ``tvm.tir.transform.prim_func_pass``,
    named "tir.ethosu.annotate_allocates".
    """

    def _mark_allocate(node):
        # Rebuild the allocate with the same fields; only the annotations differ.
        # NOTE(review): annotations already on the node are not carried over -
        # confirm that this is intended.
        no_lowering = {DISABLE_LOWER_BUILTIN: True}
        return tvm.tir.Allocate(
            buffer_var=node.buffer_var,
            dtype=node.dtype,
            extents=node.extents,
            condition=node.condition,
            body=node.body,
            annotations=no_lowering,
        )

    def _annotate_func(func, mod, ctx):
        # Only "tir.Allocate" nodes are visited; the callback runs post-order.
        new_body = tvm.tir.stmt_functor.ir_transform(
            func.body, None, _mark_allocate, ["tir.Allocate"]
        )
        return func.with_body(new_body)

    return tvm.tir.transform.prim_func_pass(
        _annotate_func, opt_level=0, name="tir.ethosu.annotate_allocates"
    )
22 changes: 11 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,16 @@ def extract_buffer_info(
primfunc = mod.functions.items()[0][1]
for idx, const_data in param_dict.items():
param = primfunc.params[idx]
buffer_info[primfunc.buffer_map[param].data] = BufferInfo(
buffer_info[param] = BufferInfo(
const_data, const_data.shape, const_data.dtype, BufferType.constant
)

for param in primfunc.params:
if primfunc.buffer_map[param].data not in buffer_info.keys():
buffer_info[primfunc.buffer_map[param].data] = BufferInfo(
if param not in buffer_info.keys():
buffer_info[param] = BufferInfo(
None,
None,
None,
primfunc.buffer_map[param].shape,
primfunc.buffer_map[param].dtype,
BufferType.input_or_output,
)

Expand Down Expand Up @@ -223,7 +223,7 @@ def replace_npu_fm_with_address(npu_fm):
def replace_npu_address_range_with_address(npu_addr_range):
assert isinstance(npu_addr_range.address, tvm.tir.Load)
buffer = npu_addr_range.address.buffer_var
assert buffer in buffer_addresses.keys()
assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found"
address, buffer_type = buffer_addresses[buffer]
return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length)

Expand Down Expand Up @@ -269,17 +269,17 @@ def classify_io(buffer):
size_in_bytes = util.round_up(size_in_bytes, 16)
constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes))
else:
size_in_bytes = int(
(np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
)
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
if info.btype == BufferType.input_or_output:
buffer_type = classify_io(_buffer)
assert buffer_type in (BufferType.input, BufferType.output)
address = 0
buffer_addresses[_buffer] = (address, buffer_type)
else:
size_in_bytes = int(
(np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
)
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
assert info.btype == BufferType.scratch
address = scratch_size
scratch_size += size_in_bytes
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

import tvm # type: ignore
from tvm import relay
from tvm._ffi import register_object
from tvm.runtime import Object
from . import _ffi_api


class QConv2DArgs(Enum):
Expand Down Expand Up @@ -209,3 +212,30 @@ def calculate_size_bytes(expr):
element_size = type_info.bits // 8
elements = np.prod(list(expr.checked_type.shape))
return element_size * elements


@register_object("relay.ext.ethos-u.CompilationArtifact")
class CompilationArtifact(Object):
    """
    Structure holding the binary artifacts for the microNPU.

    Parameters
    ----------
    command_stream : str
        The command stream (presumably a hex string of its bytes - TODO confirm)
    encoded_constants : str
        The encoded constants (presumably a hex string of their bytes - TODO confirm)
    scratch_size : int
        The size of the scratch buffer needed
    input_size : int
        The size of the input (presumably in bytes - TODO confirm)
    output_size : int
        The size of the output (presumably in bytes - TODO confirm)
    function_name : str
        The name of the function the artifacts belong to
    """

    def __init__(
        self,
        command_stream: str,
        encoded_constants: str,
        scratch_size: int,
        input_size: int,
        output_size: int,
        function_name: str,
    ):
        # Forwarded to the FFI constructor in exactly this order.
        constructor_args = (
            command_stream,
            encoded_constants,
            scratch_size,
            input_size,
            output_size,
            function_name,
        )
        self.__init_handle_by_constructor__(
            _ffi_api.CompilationArtifact,  # type: ignore # pylint: disable=no-member
            *constructor_args,
        )
8 changes: 6 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ def convert_to_object(value, span=None):
if isinstance(value, dict):
vlist = []
for item in value.items():
if not isinstance(item[0], ObjectTypes) and not isinstance(item[0], string_types):
if (
not isinstance(item[0], ObjectTypes)
and not isinstance(item[0], string_types)
and not isinstance(item[0], Number)
):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_object(item[0]))
vlist.append(convert_to_object(item[1]))
return _ffi_api.Map(*vlist)
if isinstance(value, ObjectGeneric):
Expand Down
Loading