From b548e0301f31185c3ba750e8c169bfb91a915b9b Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 29 Sep 2021 08:28:03 +0000 Subject: [PATCH 1/2] [microNPU] Allow constants to be given as input to an operator Currently the expectation is that all constants need to be encoded, however, this is not always the case for scalar inputs. This PR ensures that constants that don't need encoding are not treated like encoded constants by the EncodeConstants pass. Change-Id: I79cf4aa10d01c4ae9ce9cdafb6f21ebb2d028126 --- .../backend/contrib/ethosu/tir/passes.py | 9 +++- .../contrib/test_ethosu/test_codegen.py | 50 +++++++++++++++++++ .../test_ethosu/test_encode_constants.py | 47 ++++++++++++++++- .../contrib/test_ethosu/test_legalize.py | 47 +++++++++++++++++ 4 files changed, 150 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 75d75a5935b4..b070b11c0bf5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -392,14 +392,19 @@ def _visit_rewrite(stmt): # For extern calls, we need to rewrite pairs of arguments corresponding to # base address load and the length of the load. new_args = [stmt.args[0]] + new_buffers = rewrite_buffer.values() for i in range(1, len(stmt.args)): # If the previous argument was a load, the current should be a length if isinstance(stmt.args[i - 1], tvm.tir.Load): load = stmt.args[i - 1] pointer = load.buffer_var if pointer in pointer_to_buffer: - new_args.append(np.prod(list(pointer_to_buffer[pointer].shape))) - continue + buffer = pointer_to_buffer[pointer] + # Only rewrite the arguments of buffers that have been encoded + if buffer in new_buffers: + new_arg = np.prod(list(pointer_to_buffer[pointer].shape)) + new_args.append(new_arg) + continue new_args.append(stmt.args[i]) return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index d62f9d161ad3..455aff0aed33 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -435,6 +435,56 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_binary_add_from_constant_scalar(accel_type): + dtype = "uint8" + ifm_shape = (1, 4, 4, 8) + + def create_relay_graph(): + inp = relay.var("input", shape=ifm_shape, dtype=dtype) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add = relay.qnn.op.add( + inp, + scalar, + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + ) + func = relay.Function(relay.analysis.free_vars(add), add) + return tvm.IRModule.from_expr(func) + + mod = create_relay_graph() + partitioned_mod = partition_for_ethosu(mod) + + # Generate reference data + input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)} + output_data = generate_ref_data(mod, input_data) + + compiled_models = infra.build_source( + partitioned_mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize( "ifm_shape, ifm2_shape", diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 5b60102162be..a6bdf9be7765 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np pytest.importorskip("ethosu.vela") import tvm @@ -23,8 +24,10 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator -from .infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise # fmt: off @@ -270,5 +273,47 @@ def _get_func(): assert reference_const_sizes == test_const_sizes +def test_constant_as_input(): + """Test to check that constants specified as inputs aren't + interpreted as an encoded constant.""" + + def get_graph(): + dtype = "uint8" + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype) + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add1 = make_ethosu_binary_elementwise( + conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype + ) + func = relay.Function(relay.analysis.free_vars(add1), add1) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + tir_mod, params = lower_to_tir(get_graph(), copy_constants()) + + # Check tile address for the scalar constant input hasn't been + # overwritten. + extern_calls = tir_mod["main"].body.body.body.body.body + binary_elmtwise = extern_calls[-1].value + args = binary_elmtwise.args + + reason = "Tile address overwritten" + assert args[26] == 0, reason + assert args[27] == 0, reason + assert args[28] == 0, reason + + # More generally, check compiles successfully to make sure + # nothing else was overrwritten. + tir_to_cs_translator.translate(tir_mod, params) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 2247234e6d68..5fea513b17f9 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -693,6 +693,53 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +def test_binary_add_from_constant_scalar(): + dtype = "uint8" + ifm_shape = (1, 4, 4, 8) + + def create_graph(): + inp = relay.var("input", shape=ifm_shape, dtype=dtype) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add = relay.qnn.op.add( + inp, + scalar, + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + ) + func = relay.Function(relay.analysis.free_vars(add), add) + return tvm.IRModule.from_expr(func) + + def verify(ext_func): + op = ext_func.body + assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8] + assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1] + assert op.args[0].checked_type.dtype == "uint8" + assert list(op.checked_type.shape) == [1, 4, 4, 8] + assert op.checked_type.dtype == "uint8" + assert op.attrs.operator_type == "ADD" + + rewriter = legalize.AddRewriter() + pattern_table = [ + ( + ethosu.AddParams.composite_name, + ethosu.qnn_add_pattern(), + lambda pat: ethosu.AddParams(pat).is_valid(), + ), + ] + + mod = create_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + @pytest.mark.parametrize( "ifm_shape, ifm2_shape, reversed_operands", [ From e1458c5905dec25e4194ca3051a833b16c70f2a2 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 16 Nov 2021 18:07:59 +0000 Subject: [PATCH 2/2] address comments Change-Id: I67b61a2d2f67de25c47d2ace0e3a22c59ba8ea15 --- tests/python/contrib/test_ethosu/test_codegen.py | 2 +- tests/python/contrib/test_ethosu/test_encode_constants.py | 4 ++-- tests/python/contrib/test_ethosu/test_legalize.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 455aff0aed33..93af66da8194 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -477,7 +477,7 @@ def create_relay_graph(): ethosu_module = imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") cmms = get_cs(ethosu_module) cmms = bytes.fromhex(cmms) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index a6bdf9be7765..cc3c68624242 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -302,8 +302,8 @@ def get_graph(): # Check tile address for the scalar constant input hasn't been # overwritten. extern_calls = tir_mod["main"].body.body.body.body.body - binary_elmtwise = extern_calls[-1].value - args = binary_elmtwise.args + binary_elementwise = extern_calls[-1].value + args = binary_elementwise.args reason = "Tile address overwritten" assert args[26] == 0, reason diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 5fea513b17f9..8c3e4e31c1ca 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -734,10 +734,10 @@ def verify(ext_func): mod = create_graph() mod = partition_ethosu_by_table(mod, pattern_table) - mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( - rewriter, mod["tvmgen_default_ethosu_main_0"] + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethos_u_main_0"] ) - verify(mod["tvmgen_default_ethosu_main_0"]) + verify(mod["tvmgen_default_ethos_u_main_0"]) @pytest.mark.parametrize(