Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions python/tvm/contrib/target/android_nnapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""BYOC External Compiler Implementation for Android NNAPI target."""
import tvm
from .compiler import Compiler


def _get_c_type(tipe):
"""Get matching C type for Relay types."""
dtype = str(tipe.dtype)
if dtype == "float32":
return "float"
if dtype == "float16":
return "uint16_t"
if dtype == "int32":
return "int32_t"
assert dtype == "int64", f"{dtype} is unsupported"
return "int64_t"


@tvm.register_func("relay.ext.android_nnapi")
def _codegen(func):
"""Codegen Relay IR to Android NNAPI.

Parameters
----------
func: tvm.relay.Function
The Relay IR function to be codegened.

Returns
-------
mod: runtime.CSourceModule
The resulting Android NNAPI in C++ source code.

Notes
-----
Certain function attributes should be configured:

* func.attrs.NnapiTargetVersion: (int) The targeting API level of Android.
"""
assert isinstance(func, tvm.relay.Function), "Only Function can be codegened to Android NNAPI"
code = """#include <cstdlib>
#include <cstring>
#include <cstdint>
#include <vector>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
#include <android/NeuralNetworks.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <dlpack/dlpack.h>

namespace {
"""

sid = str(func.attrs.global_symbol)
class_name = sid + "_class"
options = {
"class": {
"self": {
"name": class_name,
},
},
"target": {
"api_level": int(func.attrs.NnapiTargetVersion),
},
}
code += Compiler(options).codegen(func)
code += "\n"

instance_name = sid + "_model"
code += f" {class_name} {instance_name};\n"

sid_impl_name = sid + "_"
code += f" void {sid_impl_name}"
code += "(::tvm::runtime::TVMArgs args, ::tvm::runtime::TVMRetValue *rv) {\n"
code += f" CHECK_EQ(args.num_args, {len(func.params) + 1})"
code += f'<< "num_args is expected to be {len(func.params) + 1}";\n'
code += f" {instance_name}.execute(\n"
for i, p in enumerate(func.params):
assert isinstance(
p.checked_type, tvm.relay.TensorType
), "Function parameter is expected to be a tensor"
code += f" reinterpret_cast< {_get_c_type(p.checked_type)}* >"
code += f"(args[{i}].operator DLTensor*()->data), \n"
assert isinstance(
func.body.checked_type, tvm.relay.TensorType
), "Function output is expected to be a tensor"
code += f" reinterpret_cast< {_get_c_type(func.body.checked_type)}* >"
code += f"(args[{len(func.params)}].operator DLTensor*()->data)\n"
code += f" );\n"
code += " *rv = 0;\n"
code += f" }} // {sid_impl_name}\n"
code += "} // anonymous namespace\n"
code += f"TVM_DLL_EXPORT_PACKED_FUNC({sid}, {sid_impl_name});\n"

return tvm.get_global_func("runtime.CSourceModuleCreate")(code, "c", [sid], [])
18 changes: 18 additions & 0 deletions python/tvm/contrib/target/android_nnapi/_export_object/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Internal namespaces of ExportObject."""
from .json_analyzer import JSONAnalyzer
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for methods that analyzes the exported JSON."""


class JSONAnalyzer:
"""Analyzing methods of the JSON format of Android NNAPI model."""

class _Operand:
"""Android NNAPI Operand-related analyzing methods on the exported JSON."""

def __init__(self, export_json):
self._export_json = export_json

def get_dtype(self, idx):
"""Get operand dtype.

Parameters
----------
idx: int
operand to be queried.

Returns
-------
dtype: str
dtype of the queried operand.
"""
return self._export_json["types"][self._export_json["operands"][idx]["type"]]["type"]

def get_shape(self, idx):
"""Get operand shape.

Parameters
----------
idx: int
operand to be queried.

Returns
-------
shape: tuple of int or None
shape of the queried operand. None if operand has no shape.
"""
return self._export_json["types"][self._export_json["operands"][idx]["type"]].get(
"shape", None
)

def get_rank(self, idx):
"""Get operand rank.

Parameters
----------
idx: int
operand to be queried.

Returns
-------
rank: int
rank of the queried operand.
"""
shape = self.get_shape(idx)
if shape is None:
return 0
return len(shape)

def get_value(self, idx):
"""Get operand value.

Parameters
----------
idx: int
operand to be queried.

Returns
-------
value:
value of the queried operand. None if there's no value.
"""
value_dict = self._export_json["operands"][idx].get("value", None)
if value_dict is None:
return None

if value_dict["type"] == "constant_idx":
return self._export_json["constants"][value_dict["value"]]["value"]
assert value_dict["type"] == "memory_ptr"
return value_dict["value"]

def get_constant(self, idx):
"""Get operand constant.

Parameters
----------
idx: int
operand to be queried.

Returns
-------
obj: dict
constant dict of the queried operand. None if there's no value.
"""
value_dict = self._export_json["operands"][idx].get("value", None)
if value_dict is None or value_dict["type"] != "constant_idx":
return None
return self._export_json["constants"][value_dict["value"]]

def is_fuse_code(self, idx):
"""Check whether the operand pointed by idx is a FuseCode

Parameters
----------
idx: int
the index of the queried operand.

Returns
-------
b: bool
the queried operand is a FuseCode or not.
"""
dtype = self.get_dtype(idx)
if dtype != "INT32":
return False
shape = self.get_shape(idx)
if shape is not None:
return False
value = self.get_value(idx)
return value in {
"ANEURALNETWORKS_FUSED_NONE",
"ANEURALNETWORKS_FUSED_RELU",
"ANEURALNETWORKS_FUSED_RELU1",
"ANEURALNETWORKS_FUSED_RELU6",
}

def __init__(self, export_json):
self.operand = JSONAnalyzer._Operand(export_json)
101 changes: 101 additions & 0 deletions python/tvm/contrib/target/android_nnapi/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Compile a Relay IR Function into Android NNAPI C++ class."""
import copy
import tvm
from . import transform
from . import json_to_nnapi
from .function_to_json_compiler import FunctionToJsonCompiler


class Compiler:
"""Compile a Relay IR Function into Android NNAPI C++ class.

Parameters
----------
options: dict
The compiler option dict. See below for available options.

options["class"]["self"]["name"]: str
The name of the C++ class wrapping the Android NNAPI model. Defaults to "AnnGraph".

options["target"]["api_level"]: int
The targeting Android API level. Defaults to 29.
"""

DEFAULT_OPTIONS = {
"class": {
"self": {
"name": "AnnGraph",
},
},
"target": {
"api_level": 29,
},
}

def __init__(self, options):
self._options = self._expand_options(options)

def codegen(self, func):
"""Compile a Relay IR Function into Android NNAPI C++ class source code

Parameters
----------
func: tvm.relay.Function
The Relay IR Function to be compiled

Returns
-------
code: str
The C++ class source code describing func in Android NNAPI
"""
assert isinstance(func, tvm.relay.Function)
func = transform.FixIllegalPatternForNnapi()(func)

mod = tvm.IRModule({"main": func})
export_obj = FunctionToJsonCompiler(self._options)(mod["main"])

ret = json_to_nnapi.codegen(
export_json=export_obj.asjson(),
options={
"class": {
"name": self._options["class"]["self"]["name"],
},
},
)
return ret

@classmethod
def _expand_options(cls, options):
ret = copy.deepcopy(options)

def _recursive_merge(cur_opts, def_opts):
for k, v in def_opts.items():
if k in cur_opts:
if isinstance(v, dict):
assert isinstance(cur_opts[k], dict)
_recursive_merge(cur_opts[k], v)
else:
# type(cur_opts[k]) should be a basic type
assert isinstance(cur_opts[k], (float, int, str))
else: # option k does not exist in current options, so copy from default options
cur_opts[k] = copy.deepcopy(v)

_recursive_merge(ret, cls.DEFAULT_OPTIONS)

return ret
Loading