diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py
index f80c7d963088..5d51c7bfae20 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py
@@ -146,6 +146,12 @@ def write_compute(
 def convert_to_nhwc_compute(tensor: te.Tensor, layout: str, channels: int) -> te.Tensor:
     """Converts a tensor into NHWC layout if it's in NHWCB16 layout.
 
+    When the current layout is NHCWB16, a reduce sum operation is inserted
+    to ensure that the whole of the input tensor has a data dependency on
+    the copy operation. Without this, TVM removes compute that is deemed to
+    be unnecessary, which causes strides for the NPU to be calculated
+    incorrectly.
+
     Parameters
     ----------
     tensor : te.Tensor
@@ -167,9 +173,12 @@ def convert_to_nhwc_compute(tensor: te.Tensor, layout: str, channels: int) -> te
         "layout": layout,
     }
     if layout == "NHCWB16":
+        rc = te.reduce_axis((0, 16), name="rc")
         return te.compute(
             (tensor.shape[0], tensor.shape[1], tensor.shape[3], channels),
-            lambda nn, hh, ww, cc: tensor(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(cc, 16)),
+            lambda nn, hh, ww, cc: te.sum(
+                tensor(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(rc, 16)), axis=rc
+            ),
             name="ethosu_convert_to_nhwc",
             attrs=convert_to_nhwc_attrs,
         )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
index a116e51c5b7c..46df20814eb5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
@@ -94,10 +94,18 @@ def get_convert_to_nhwc_params(stmt):
         The pointer produced by the operation.
 
     """
-    _, body = get_op_attrs(stmt)
+    attrs, body = get_op_attrs(stmt)
     _, _, _, c, _, inner = get_outer_loops(body, "NHWC")
+
+    # Ignore the reduce sum operation inserted to ensure
+    # compute that is deemed unnecessary isn't removed by TVM.
+ if attrs["layout"] == "NHCWB16": + inner = inner.body + input_pointer = inner.value.b.buffer_var + else: + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var - input_pointer = inner.value.buffer_var return c.extent, input_pointer, output_pointer diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py index 79526ed527e8..ee72ffa4cb99 100644 --- a/tests/python/contrib/test_ethosu/test_replace_pooling.py +++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py @@ -26,52 +26,18 @@ from .infra import make_ethosu_pooling, get_pooling_args -@pytest.mark.parametrize( - "ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode", - [ - ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"), - ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"), - ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"), - ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"), - ], -) -@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"]) -@pytest.mark.parametrize("activation", ["NONE", "CLIP"]) -def test_pooling_single( +def _create_serial_pooling( ifm_shape, ofm_channels, ifm_layout, ofm_layout, + pool_shape, pooling_type, - activation, - rounding_mode, + strides, + padding, + activation="NONE", + rounding_mode="TFL", ): - pool_shape = (3, 2) - strides = (1, 2) - padding = (1, 1, 1, 0) - ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") - pooling = make_ethosu_pooling( - ifm, - pooling_type, - pool_shape, - ofm_channels, - strides, - padding, - activation, - ifm_layout, - ofm_layout, - rounding_mode, - ) - func = relay.Function(relay.analysis.free_vars(pooling), pooling) - func = run_opt_pass(func, relay.transform.InferType()) - mod, _ = lower_to_tir(func) - data = [] - - def _visit(stmt): - if isinstance(stmt, tvm.tir.Call): - data.append(get_pooling_args(stmt)) - - tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) if ifm_layout == "NHWC": ifm_stride_c = 1 ifm_stride_w = ifm_shape[3] @@ -80,7 +46,7 @@ def _visit(stmt): ofm_width = (ifm_shape[2] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 else: ifm_stride_w = 16 - ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_c = 16 * ifm_shape[3] if ofm_channels >= 16 else 1 ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1 ofm_width = (ifm_shape[3] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 @@ -91,10 +57,10 @@ def _visit(stmt): ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1 else: ofm_stride_w = 16 - ofm_stride_c = 16 * ofm_width + ofm_stride_c = 16 * ofm_width if ofm_channels >= 16 else 1 ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1) - serial_pooling = spec.SerialPooling( + return spec.SerialPooling( ifm=spec.SerialFeatureMap( data_type="int8", height=ifm_shape[1], @@ -154,8 +120,139 @@ def _visit(stmt): upscale="NONE", ) + +@pytest.mark.parametrize( + "ifm_shape, ofm_channels, ifm_layout, ofm_layout, rounding_mode", + [ + ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"), + ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"), + ((1, 8, 9, 8), 8, "NHWC", "NHCWB16", "TFL"), + ], +) +@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"]) +@pytest.mark.parametrize("activation", ["NONE", "CLIP"]) +def test_pooling_single( + ifm_shape, + ofm_channels, + ifm_layout, + ofm_layout, + pooling_type, + activation, + 
+    rounding_mode,
+):
+    pool_shape = (3, 2)
+    strides = (1, 2)
+    padding = (1, 1, 1, 0)
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    pooling = make_ethosu_pooling(
+        ifm,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+        activation,
+        ifm_layout,
+        ofm_layout,
+        rounding_mode,
+    )
+    func = relay.Function(relay.analysis.free_vars(pooling), pooling)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_pooling_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+
+    serial_pooling = _create_serial_pooling(
+        ifm_shape,
+        ofm_channels,
+        ifm_layout,
+        ofm_layout,
+        pool_shape,
+        pooling_type,
+        strides,
+        padding,
+        activation,
+        rounding_mode,
+    )
     assert data[0] == ["ethosu_pooling"] + list(serial_pooling)
 
 
+def test_correct_stride_with_multiple_pooling():
+    """Testing a specific case of two pooling operations with NHWC inputs/outputs
+    but an NHCWB16 intermediate tensor. This led to elements being accessed in the
+    wrong order by the NPU, due to incorrect stride values being calculated."""
+
+    ifm_shape = (1, 4, 4, 8)
+    ofm_channels = 8
+    pooling_type = "MAX"
+    pool_shape = (1, 1)
+    strides = (1, 1)
+    padding = (0, 0, 0, 0)
+
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    op = make_ethosu_pooling(
+        ifm,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+        ifm_layout="NHWC",
+        ofm_layout="NHCWB16",
+    )
+    op = make_ethosu_pooling(
+        op,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+        ifm_layout="NHCWB16",
+        ofm_layout="NHWC",
+    )
+    func = relay.Function(relay.analysis.free_vars(op), op)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = lower_to_tir(func)
+
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_pooling_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+
+    serial_pooling_1 = _create_serial_pooling(
+        [1, 4, 4, 8],
+        8,
+        "NHWC",
+        "NHCWB16",
+        pool_shape,
+        pooling_type,
+        strides,
+        padding,
+    )
+    serial_pooling_2 = _create_serial_pooling(
+        [1, 4, 1, 4, 16],
+        8,
+        "NHCWB16",
+        "NHWC",
+        pool_shape,
+        pooling_type,
+        strides,
+        padding,
+    )
+
+    assert data[0] == ["ethosu_pooling"] + list(serial_pooling_1)
+    assert data[1] == ["ethosu_pooling"] + list(serial_pooling_2)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
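
The change to convert_to_nhwc_compute above relies on the reduce sum creating a data dependency on every element of the 16-wide NHCWB16 brick, so the copy compute cannot be simplified away. The following standalone sketch is not part of the patch; the buffer name, shapes and channel count are made up for illustration. It reproduces the same TE pattern in isolation so the lowered TIR can be inspected directly:

import tvm
from tvm import te

# Hypothetical NHCWB16 buffer laid out as (N, H, C//16, W, 16); shapes are illustrative only.
brick = te.placeholder((1, 4, 1, 4, 16), dtype="int8", name="brick")
channels = 8

# Reducing over a 16-wide axis makes every element of the 16-channel brick an
# input of the NHWC result, so the conversion compute keeps a dependency on the
# whole input tensor rather than a single element per output position.
rc = te.reduce_axis((0, 16), name="rc")
nhwc = te.compute(
    (brick.shape[0], brick.shape[1], brick.shape[3], channels),
    lambda nn, hh, ww, cc: te.sum(
        brick(nn, hh, te.indexdiv(cc, 16), ww, te.indexmod(rc, 16)), axis=rc
    ),
    name="convert_to_nhwc_sketch",
)

# The lowered TIR contains a loop over rc, i.e. a read of all 16 brick elements.
print(tvm.lower(te.create_schedule(nhwc.op), [brick, nhwc], simple_mode=True))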