90 changes: 74 additions & 16 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -16,8 +16,9 @@
# under the License.
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
from collections import defaultdict

from typing import List, Callable

from ethosu.vela import api as vapi
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR
@@ -30,7 +31,7 @@
extract_memory_info,
)
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util, vela_api
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

# pylint: disable=unused-import
@@ -143,20 +144,25 @@ def __call__(self, *args, **kwargs):


class AnalyzeConsumers(ExprVisitor):
"""Traverses the graph to determine consumers that are NPU operations. The
result is maintained in `npu_consumers`.
"""Traverses the graph to determine consumers that are NPU operations and
which have restrictions to use NHCWB16 layout. The result is maintained in
`npu_consumers` and `restrictions`.

Attributes
----------
npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
Mapping from NPU operation to list of boolean values that represent
whether or not each consumer is an NPU operation.
restrictions : Dict[tvm.relay.expr.Call, List[bool]]
Mapping from NPU operation to a list of boolean values that represent
whether or not the operation is restricted from using the NHCWB16 layout.
optimize_ops : Dict[str, Callable]
A map from NPU operation name to function that creates NPU operation.
"""

def __init__(self, optimize_ops):
self.npu_consumers = defaultdict(list)
self.restrictions = defaultdict(list)
self.optimize_ops = optimize_ops
super().__init__()

@@ -174,6 +180,18 @@ def visit_call(self, call: relay.Call):
for arg in args:
if isinstance(arg, relay.Call) and arg.op.name in self.optimize_ops:
self.npu_consumers[arg].append(is_npu_consumer)
# ReduceSum requires an NHWC input when the input tensor has type int32 or
# the accelerator is Ethos_U65_512
# https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.7.0/ethosu/vela/graph_optimiser_util.py#126
has_restrictions = (
call.op.name == "contrib.ethosu.pooling"
and call.attrs["pooling_type"] == "SUM"
and (
arg.checked_type.dtype == "int32"
or vela_api.get_accelerator_config() == vapi.NpuAccelerator.Ethos_U65_512
)
)
self.restrictions[arg].append(has_restrictions)

super().visit_call(call)

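For illustration (not part of the patch): with a graph where an NPU ADD feeds a SUM pooling whose input is int32, the visitor above would record something like the following (`analyze` is an AnalyzeConsumers instance as used later in the pass; `add_call` is a hypothetical handle to the ADD call):

#   analyze.npu_consumers[add_call] == [True]   # the SUM pooling consumer is an NPU op
#   analyze.restrictions[add_call]  == [True]   # int32 ReduceSum consumer -> producer must keep NHWC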
@@ -185,11 +203,11 @@ class LayoutOptimization(ExprMutator):
operation depends on the following:

Check alter input layout: For each argument, if the producer is also an NPU operation and
its output is altered to brick format, then the input layout with respect to the current
argument is altered to brick format.
its output is altered to brick format and there are no restrictions, then the input layout
with respect to the current argument is altered to brick format.

Check alter output layout: If all consumers (child nodes) are an NPU operation, then the
output layout is altered to brick format.
Check alter output layout: If all consumers (child nodes) are NPU operations and
there are no restrictions, then the output layout is altered to brick format.

Note
----
@@ -198,15 +216,19 @@ class LayoutOptimization(ExprMutator):

Attributes
----------
npu_consumers : Dict[tvm.relay.expr.Call, bool]
npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
A map from the current call to a list of boolean values that state whether or not each consumer
is an NPU operation.
restrictions : Dict[tvm.relay.expr.Call, List[bool]]
A map from the current call to a list of boolean values that state
whether or not the operation is restricted from using the NHCWB16 layout.
optimize_ops : Dict[str, Callable]
A map from NPU operation name to function that creates NPU operation.
"""

def __init__(self, npu_consumers, optimize_ops):
def __init__(self, npu_consumers, restrictions, optimize_ops):
self.npu_consumers = npu_consumers
self.restrictions = restrictions
self.optimize_ops = optimize_ops
super().__init__()

@@ -224,6 +246,39 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
new_call : tvm.relay.expr.Call
New call with altered layouts.
"""

def are_all_consumers_npu(call):
"""
Check whether or not each consumer is an NPU operation.

Parameters
----------
call : tvm.relay.expr.Call
The call pointing to an NPU operation.

Returns
-------
all_consumers_npu : bool
Whether each consumer is an NPU operation.
"""
consumers = self.npu_consumers[call]
return consumers and all(consumers)

def check_restrictions(call):
"""
Check if there are any restrictions preventing `call` from using the NHCWB16 layout.

Parameters
----------
call : tvm.relay.expr.Call
The call pointing to an NPU operation.

Returns
-------
any_restrictions : bool
Whether there are restrictions.
"""
restrictions = self.restrictions[call]
return restrictions and any(restrictions)

assert isinstance(call.attrs, tvm.ir.Attrs), (
f"The attributes for operator '{call.op.name}' could not be "
"found. Did you register the relay.attrs.Ethosu<opname>Attrs "
@@ -238,15 +293,16 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
input_count += 1
if arg not in self.npu_consumers:
continue
consumers = self.npu_consumers[arg]
parent_has_brick_output = consumers and all(consumers)
if parent_has_brick_output:
parent_has_brick_output = are_all_consumers_npu(arg)
parent_has_restrictions = check_restrictions(arg)
if parent_has_brick_output and not parent_has_restrictions:
layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout"
new_attrs[layout_string] = "NHCWB16"

# Check if we can rewrite the output layouts
consumers = self.npu_consumers[call]
if consumers and all(consumers):
has_brick_output = are_all_consumers_npu(call)
has_restrictions = check_restrictions(call)
if has_brick_output and not has_restrictions:
new_attrs["ofm_layout"] = "NHCWB16"

name = call.op.name
@@ -293,7 +349,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:

analyze = AnalyzeConsumers(optimize_ops)
analyze.visit(func)
return LayoutOptimization(analyze.npu_consumers, optimize_ops).visit(func)
return LayoutOptimization(analyze.npu_consumers, analyze.restrictions, optimize_ops).visit(
func
)

def __call__(self, *args, **kwargs):
pass
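As a minimal standalone sketch (not part of the patch) of the decision rule the updated pass applies: a producer's output, and the matching consumer input, are only switched to the NHCWB16 brick format when every recorded consumer is an NPU operation and none of them records a restriction. The helper name and dictionary arguments below are illustrative only; the dictionaries are those built by AnalyzeConsumers.

def choose_layout(call, npu_consumers, restrictions):
    # All consumers must be NPU operations for the brick format to be usable.
    consumers = npu_consumers[call]
    all_npu = bool(consumers) and all(consumers)
    # Any recorded restriction (e.g. an int32 ReduceSum consumer) forces NHWC.
    restricted = restrictions[call]
    any_restriction = bool(restricted) and any(restricted)
    return "NHCWB16" if all_npu and not any_restriction else "NHWC"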
18 changes: 12 additions & 6 deletions tests/python/contrib/test_ethosu/infra.py
@@ -702,18 +702,24 @@ def make_ethosu_binary_elementwise(
rescale_scale: int = 0,
rescale_shift: int = 0,
lut=relay.const([], dtype="int8"),
ifm_scale: float = 1.0,
ifm_zero_point: int = 0,
ifm2_scale: float = 1.0,
ifm2_zero_point: int = 0,
ofm_scale: float = 1.0,
ofm_zero_point: int = 0,
):
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
ifm=ifm,
ifm2=ifm2,
lut=lut,
operator_type=operator_type,
ifm_scale=1,
ifm_zero_point=0,
ifm2_scale=1,
ifm2_zero_point=0,
ofm_scale=1,
ofm_zero_point=0,
ifm_scale=ifm_scale,
ifm_zero_point=ifm_zero_point,
ifm2_scale=ifm2_scale,
ifm2_zero_point=ifm2_zero_point,
ofm_scale=ofm_scale,
ofm_zero_point=ofm_zero_point,
ifm_channels=ifm_channels,
ifm2_channels=ifm2_channels,
reversed_operands=reversed_operands,
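A hedged usage sketch of the extended helper, mirroring how the new tests below exercise it: the quantization parameters are the ones added in this change, the other arguments shown take the values used by the tests, and `ifm`/`ifm2` stand in for relay variables.

add = infra.make_ethosu_binary_elementwise(
    ifm,
    ifm2,
    ifm_channels=4,
    ifm2_channels=4,
    operator_type="ADD",
    ofm_dtype="int32",
    ifm_scale=0.0,   # zero scales are what the new int32 reduce_sum test passes in
    ifm2_scale=0.0,
)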
56 changes: 56 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -520,6 +520,62 @@ def sum_func(x):
)


# Test case to check the reduce_sum operation with different input types.
@pytest.mark.parametrize("dtype", ["int8", "int32"])
def test_add_reduce_sum(dtype):
ifm_shape = (1, 2, 2, 4)
accel_type = "ethos-u55-256"
np.random.seed(0)

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype)
ifm_scale = 0.0 if dtype == "int32" else 1.0
op = infra.make_ethosu_binary_elementwise(
ifm,
ifm2,
ifm_shape[3],
ifm_shape[3],
"ADD",
dtype,
ifm_scale=ifm_scale,
ifm2_scale=ifm_scale,
)
op = infra.make_ethosu_pooling(
ifm=op,
pooling_type="SUM",
pool_shape=(1, 1),
ofm_channels=1,
strides=(1, 1),
padding=(0, 0, 0, 0),
rounding_mode="NATURAL",
)
return tvm.IRModule.from_expr(relay.Function([ifm, ifm2], op))

def generate_output_data(input_data):
lhs = input_data["ifm"]
rhs = input_data["ifm2"]
# reduce_sum output type is int32.
output_dtype = "int32"
add = lhs + rhs
return [np.sum(add, axis=3).astype(output_dtype)]

cpu_mod = create_model()

# Generate reference data
in_min, in_max = -10, 19
lhs = np.random.randint(in_min, in_max, size=ifm_shape, dtype=dtype)
rhs = np.random.randint(in_min, in_max, size=ifm_shape, dtype=dtype)
input_data = {
"ifm": lhs,
"ifm2": rhs,
}
output_data = {"output": generate_output_data(input_data)[0]}
ethosu_mod = infra.create_ethosu_partition(cpu_mod)

infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
38 changes: 38 additions & 0 deletions tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -121,6 +121,44 @@ def get_graph():
_assert_structural_equal(a, b)


@pytest.mark.parametrize("dtype", ["int8", "int32"])
def test_add_reduce_sum(dtype):
"""Test add with reduce sum to make sure the layouts remain
unaltered for int32 and altered for other types.
"""

def get_graph(get_expected=False):
in_1 = relay.var("x", shape=(1, 2, 2, 2), dtype=dtype)
in_2 = relay.var("y", shape=(1, 2, 2, 2), dtype=dtype)
layout = "NHCWB16" if get_expected and dtype != "int32" else "NHWC"
add = infra.make_ethosu_binary_elementwise(
in_1,
in_2,
ifm_channels=2,
ifm2_channels=2,
operator_type="ADD",
ofm_dtype=dtype,
ifm_layout="NHWC",
ifm2_layout="NHWC",
ofm_layout=layout,
)
x = infra.make_ethosu_pooling(
ifm=add,
pooling_type="SUM",
pool_shape=(1, 1),
ofm_channels=1,
strides=(1, 1),
padding=(0, 0),
ifm_layout=layout,
ofm_layout="NHWC",
)
return relay.Function(relay.analysis.free_vars(x), x)

a = _optimize(get_graph())
b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


def test_multiple_convolution():
"""Test layout optimization pass on linear chain of convolutions. I.e,
