Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 125 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
"""A set of passes to legalize some of operations for the NPU"""
from typing import List
from typing import List, Type

import numpy as np # type: ignore

import tvm # type: ignore
Expand All @@ -26,6 +27,7 @@
from tvm.relay.dataflow_pattern import wildcard
from tvm.relay.dataflow_pattern import is_op
from tvm.relay.dataflow_pattern import rewrite
from tvm.relay.dataflow_pattern import CallPattern
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore
from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore
from tvm.relay.backend.contrib.ethosu import vela_api
Expand Down Expand Up @@ -121,7 +123,7 @@ def __call__(self, *args, **kwargs):
pass


class EthosUConv2DRewriter(DFPatternCallback):
class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""

def __init__(self):
Expand Down Expand Up @@ -193,22 +195,22 @@ def callback(


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosUConv2D:
"""This is the pass that wraps the EthosUConv2DRewriter"""
class LegalizeConv2D:
"""This is the pass that wraps the Conv2DRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(EthosUConv2DRewriter(), func)
func = rewrite(Conv2DRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class EthosuDepthwiseConv2DRewriter(DFPatternCallback):
class DepthwiseConv2DRewriter(DFPatternCallback):
"""Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d
operators"""

Expand Down Expand Up @@ -286,14 +288,124 @@ def callback(


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosUDepthwiseConv2D:
"""This is the pass that wraps the EthosUDepthwiseConv2DRewriter"""
class LegalizeDepthwiseConv2D:
"""This is the pass that wraps the DepthwiseConv2DRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(DepthwiseConv2DRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class PoolingRewriter(DFPatternCallback):
    """Convert ethosu.avgpool2d and ethosu.maxpool2d composite functions to
    ethosu_pooling operators.

    This is the shared implementation; subclasses supply the parameter
    extraction class and the pattern for one concrete pooling flavour.

    Parameters
    ----------
    params_class : Type
        Class that is instantiated with the body of the matched composite
        function in order to extract the pooling parameters.
    pattern : CallPattern
        Call pattern stored on the callback as ``self.pattern``.
    """

    def __init__(
        self,
        params_class: Type,
        pattern: CallPattern,
    ):
        super().__init__(require_type=True)
        self.params_class = params_class
        self.pattern = pattern

    def callback(
        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
    ) -> tvm.relay.Expr:
        """Build the ethosu_pooling call that replaces the matched composite call.

        Raises
        ------
        UnsupportedLayout
            If the output feature map layout is not "NHWC".
        """
        params = self.params_class(post.op.body)
        params.ifm.tensor = post.args[0]

        # Index of the channel axis for every output layout handled here.
        channel_axis = {
            "NHWC": 3,
        }
        ofm_layout = str(params.ofm.layout)
        if ofm_layout not in channel_axis:
            raise UnsupportedLayout(ofm_layout)

        # Start from "no activation" and override only when one was fused in.
        activation, clip_min, clip_max = "NONE", 0, 0
        if params.activation:
            activation = {"clip": "CLIP"}[params.activation.op.name]
            clip_min = int(params.activation.attrs.a_min)
            clip_max = int(params.activation.attrs.a_max)

        # Activations requiring a LUT are currently not supported, so an empty table is passed
        lut = relay.const([], dtype="int8")

        return ethosu_ops.ethosu_pooling(
            ifm=post.args[0],
            lut=lut,
            pooling_type=params.pooling_type,
            ifm_scale=params.ifm.q_params.scale_f32,
            ifm_zero_point=params.ifm.q_params.zero_point,
            ofm_scale=params.ofm.q_params.scale_f32,
            ofm_zero_point=params.ofm.q_params.zero_point,
            pool_shape=params.pool_shape,
            ofm_channels=params.ofm.shape[channel_axis[ofm_layout]],
            strides=params.strides,
            padding=params.padding,
            activation=activation,
            clip_min=clip_min,
            clip_max=clip_max,
            upscale="NONE",
            ifm_layout=str(params.ifm.layout),
            ofm_layout=ofm_layout,
        )


class MaxPoolingRewriter(PoolingRewriter):
    """PoolingRewriter configured with MaxPool2DParams and a pattern matching
    calls to functions whose "Composite" attribute is its composite name"""

    def __init__(self):
        params_cls = ethosu_patterns.MaxPool2DParams
        # Any function tagged with the max-pool composite name ...
        composite_func = wildcard().has_attr({"Composite": params_cls.composite_name})
        # ... called with a single, arbitrary argument.
        super().__init__(params_class=params_cls, pattern=composite_func(wildcard()))


@ir.transform.module_pass(opt_level=1)
class LegalizeMaxPooling:
    """Module pass that applies the MaxPoolingRewriter to every function of a module"""

    def transform_module(
        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
    ) -> tvm.ir.IRModule:
        """Rewrite each function in ``mod`` and return the updated module."""
        for gvar, original_func in mod.functions.items():
            # A fresh rewriter is created for every function.
            mod.update_func(gvar, rewrite(MaxPoolingRewriter(), original_func))
        return mod

    def __call__(self, *args, **kwargs):
        # NOTE(review): intentionally empty; presumably the module_pass decorator
        # supplies the real call behaviour -- confirm against the TVM pass docs.
        pass


class AvgPoolingRewriter(PoolingRewriter):
    """PoolingRewriter configured with AvgPool2DParams and a pattern matching
    calls to functions whose "Composite" attribute is its composite name"""

    def __init__(self):
        params_cls = ethosu_patterns.AvgPool2DParams
        # Any function tagged with the avg-pool composite name ...
        composite_func = wildcard().has_attr({"Composite": params_cls.composite_name})
        # ... called with a single, arbitrary argument.
        super().__init__(params_class=params_cls, pattern=composite_func(wildcard()))


@ir.transform.module_pass(opt_level=1)
class LegalizeAvgPooling:
"""This is the pass that wraps the AvgPoolingRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(EthosuDepthwiseConv2DRewriter(), func)
func = rewrite(AvgPoolingRewriter(), func)
mod.update_func(global_var, func)
return mod

Expand All @@ -312,8 +424,10 @@ def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
mod = LegalizeSplit()(mod)
mod = LegalizeEthosUConv2D()(mod)
mod = LegalizeEthosUDepthwiseConv2D()(mod)
mod = LegalizeConv2D()(mod)
mod = LegalizeDepthwiseConv2D()(mod)
mod = LegalizeMaxPooling()(mod)
mod = LegalizeAvgPooling()(mod)
return mod

def __call__(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@

from .convolution import ethosu_conv2d
from .depthwise import ethosu_depthwise_conv2d
from .pooling import ethosu_pooling
6 changes: 3 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/op/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def ethosu_conv2d(
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
"""This is a quantized 2D convolution operation as supported by the
the NPU. It accepts either NHWC or NHCWB16 format
Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
for the input data and OHWI format for the kernel weights.

Reference: https://developer.arm.com/documentation/102420/0200/
Expand All @@ -132,7 +132,7 @@ def ethosu_conv2d(
scale_bias : tvm.relay.Expr
The packed per-channel weight scale and bias tensor.
lut : tvm.relay.Expr
The look-up table values to use if activation = "LUT".
The look-up table of values to use if activation = "LUT".
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
Expand All @@ -146,7 +146,7 @@ def ethosu_conv2d(
kernel_shape : tuple of int
The 2 dimensional kernel shape as (kernel_height, kernel_width).
ofm_channels : int
The number of OFM channels.
The number of the Output Feature Map channels.
strides : tuple of int, optional
The 2 dimensional strides as (stride_height, stride_width).
padding : tuple of int, optional
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def ethosu_depthwise_conv2d(
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
"""This is a quantized 2D depthwise convolution operation as supported
by the NPU. It accepts either NHWC or NHCWB16 format
"""This is a quantized 2D depthwise convolution operation as supported by the
Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
for the input data and OHWI format for the kernel weights.

Reference: https://developer.arm.com/documentation/102420/0200/
Expand All @@ -132,7 +132,7 @@ def ethosu_depthwise_conv2d(
scale_bias : tvm.relay.Expr
The packed per-channel weight scale and bias tensor.
lut : tvm.relay.Expr
The look-up table values to use if activation = "LUT"
The look-up table of values to use if activation = "LUT"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
Expand All @@ -146,7 +146,7 @@ def ethosu_depthwise_conv2d(
kernel_shape : tuple of int
The 2 dimensional kernel shape as (kernel_height, kernel_width).
ofm_channels : int
The number of OFM channels.
The number of the Output Feature Map channels.
strides : tuple of int, optional
The 2 dimensional strides as (stride_height, stride_width).
padding : tuple of int, optional
Expand Down
Loading