Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
# pylint: disable=ungrouped-imports, import-outside-toplevel
"""Arm(R) Ethos(TM)-U NPU supported operators."""
import functools

Expand All @@ -36,14 +36,6 @@
# rely on imports from ethos-u-vela, we protect them with the decorator @requires_vela
# implemented below.
from ethosu.vela import api as vapi # type: ignore
from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import get_dim_value
except ImportError:
vapi = None

Expand Down Expand Up @@ -116,6 +108,8 @@ def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List

def check_weights(weights: TensorParams, dilation: List[int]):
"""This function checks whether weight tensor is compatible with the NPU"""
from tvm.relay.backend.contrib.ethosu.util import get_dim_value

dilated_height_range = (1, 64)
dilated_hxw_range = (1, 64 * 64)
weights_limit = 127 * 65536
Expand Down Expand Up @@ -200,6 +194,10 @@ class QnnConv2DParams:

@requires_vela
def __init__(self, func_body: tvm.relay.Function):
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

activation = None
if str(func_body.op) in self.activation_map.keys():
activation = func_body
Expand Down Expand Up @@ -472,6 +470,8 @@ class BinaryElementwiseParams:
"""

def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool):
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs

clip = None
if str(func_body.op) == "clip":
clip = func_body
Expand Down Expand Up @@ -869,6 +869,9 @@ class AbsParams:
composite_name = "ethos-u.abs"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs

quantize = func_body
abs_op = quantize.args[0]
dequantize = abs_op.args[0]
Expand Down Expand Up @@ -1037,6 +1040,8 @@ def partition_for_ethosu(
mod : IRModule
The partitioned IRModule with external global functions
"""
from tvm.relay.backend.contrib.ethosu import preprocess

if params:
mod["main"] = bind_params_by_name(mod["main"], params)

Expand Down