diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index 3e10f3d60415..7acaee9706c2 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -129,6 +129,7 @@ def partition_for_ethosn(mod, params=None, **opts): passes = [ transform.InferType(), + transform.FoldConstant(fold_qnn=True), transform.MergeComposite(pattern_table()), transform.AnnotateTarget("ethos-n"), transform.MergeCompilerRegions(), diff --git a/tests/python/contrib/test_ethosn/test_addition.py b/tests/python/contrib/test_ethosn/test_addition.py index 9841e798aff4..5813ef7b9d44 100644 --- a/tests/python/contrib/test_ethosn/test_addition.py +++ b/tests/python/contrib/test_ethosn/test_addition.py @@ -41,20 +41,28 @@ def _get_model( ): """Return a model and any parameters it may have""" - iinfo = np.iinfo(dtype) - data_min = iinfo.min - data_max = iinfo.max + def create_or_assign_constant(shape, dtype, default_data): + """Creates new numpy array or assigns default_data if available.""" + + iinfo = np.iinfo(dtype) + data_min = iinfo.min + data_max = iinfo.max + + nparray = None + if default_data: + nparray = np.array(default_data, dtype=dtype).reshape(shape) + else: + nparray = np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype) + + return relay.const(nparray, dtype=dtype) if lhs_is_constant: - a_data = np.array(constant_data, dtype=dtype).reshape(lhs_shape) - a = relay.const(a_data, dtype=dtype) + a = create_or_assign_constant(lhs_shape, dtype, constant_data) else: a = relay.var("a", shape=lhs_shape, dtype=dtype) if rhs_is_constant: - b_data = np.array(constant_data, dtype=dtype).reshape(rhs_shape) - np.random.randint(data_min, data_max + 1, size=rhs_shape, dtype=dtype) - b = relay.const(b_data, dtype=dtype) + b = create_or_assign_constant(rhs_shape, dtype, constant_data) else: b = relay.var("b", shape=rhs_shape, dtype=dtype) @@ -125,6 +133,46 @@ def test_addition(dtype, shape): tei.verify(outputs, dtype, 1) +@requires_ethosn +@pytest.mark.parametrize("dtype", ["uint8", "int8"]) +@pytest.mark.parametrize( + "lhs_shape,lhs_is_constant,rhs_shape,rhs_is_constant", + [ + ((1, 4, 4, 8), True, (1, 1, 1, 8), True), + ((4,), True, (1, 16, 12, 4), True), + ((1, 1, 1, 8), True, (1, 4, 4, 8), True), + ((1, 16, 12, 4), True, (4,), True), + ], +) +def test_addition_both_inputs_constants( + dtype, lhs_shape, lhs_is_constant, rhs_shape, rhs_is_constant +): + """Check if addition is simplified when both inputs are constants.""" + np.random.seed(0) + + lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype) + + model = _get_model( + lhs_shape, + rhs_shape, + lhs_zp, + lhs_sc, + rhs_zp, + rhs_sc, + out_zp, + out_sc, + dtype, + lhs_is_constant=lhs_is_constant, + rhs_is_constant=rhs_is_constant, + ) + from tvm.relay.op.contrib import partition_for_ethosn # pylint: disable=import-outside-toplevel + + mod = tei.make_module(model, {}) + assert "qnn.add" in mod.astext(False) + mod = partition_for_ethosn(mod, {}) + assert "qnn.add" not in mod.astext(False) + + @requires_ethosn @pytest.mark.parametrize("dtype", ["uint8", "int8"]) @pytest.mark.parametrize( @@ -145,9 +193,6 @@ def test_addition_to_depthwise(dtype, lhs_shape, lhs_is_constant, rhs_shape, rhs data_max = iinfo.max lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype) - constant_shape = lhs_shape if lhs_is_constant else rhs_shape - constant_data = np.random.randint(data_min, data_max + 1, size=constant_shape, dtype=dtype) - model = _get_model( lhs_shape, rhs_shape, @@ -160,7 +205,6 @@ def test_addition_to_depthwise(dtype, lhs_shape, lhs_is_constant, rhs_shape, rhs dtype, lhs_is_constant=lhs_is_constant, rhs_is_constant=rhs_is_constant, - constant_data=constant_data, ) input_shape = rhs_shape if lhs_is_constant else lhs_shape input_name = "b" if lhs_is_constant else "a" diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index 23ff5207fbcd..dfbd262abf96 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -218,6 +218,6 @@ def test_ssd_mobilenet_v1(): input_dict={"normalized_input_image_tensor": (1, 300, 300, 3)}, compile_hash=_compile_hash, output_count=4, - host_ops=26, + host_ops=14, npu_partitions=1, )