From dd4242ffa472754325628c9459f4ef9f1534e7e2 Mon Sep 17 00:00:00 2001 From: Aleksei-grovety <113356454+Aleksei-grovety@users.noreply.github.com> Date: Fri, 28 Jul 2023 11:43:10 +0400 Subject: [PATCH] [microNPU][ETHOSU] Fix concatenation with reused buffers Add a pass to copy concatenation arguments which are used more than once in concatenation operation to prevent a situation where an argument used in multiple concatenations will be written to only one resulting buffer. --- .../relay/backend/contrib/ethosu/codegen.py | 101 ++++++++++++++++++ .../contrib/test_ethosu/test_codegen.py | 18 ++++ 2 files changed, 119 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index f4cea5df1364..b2fc5f0af289 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -444,6 +444,106 @@ def replicate_pads(mod): return mod +class AnalyzeConcatArgs(ExprVisitor): + """Traverses the graph to determine which arguments were passed into the + concatenation operation and how many times they are used. The result is + maintained in `args_usage` and is a dictionary where the key is the concatenation argument and + the value is the number of uses of this argument. + + Attributes + ---------- + args_usage : Dict[tvm.relay.expr.Call, int] + Mapping from concatenation arguments to count their usage as concatenate arguments. + """ + + def __init__(self): + self.args_usage = defaultdict(int) + super().__init__() + + def visit_call(self, call: relay.Call): + args = [] + + # Expand tuples + for arg in call.args: + if isinstance(arg, relay.Tuple): + args.extend(arg.fields) + else: + args.append(arg) + + if isinstance(call.op, tvm.ir.Op) and call.op.name == "concatenate": + for arg in args: + if isinstance(arg, relay.Call): + self.args_usage[arg] += 1 + + super().visit_call(call) + + +class ConcatArgsCopier(ExprMutator): + """A pass for copying concatenation arguments that are used in multiple concatenation + operations. For a concatenation argument that is used n times, n - 1 copy operations + will be created. + + Attributes + ---------- + args_usage : Dict[tvm.relay.expr.Call, int] + Mapping from concatenation arguments to count their usage as concatenate arguments. + """ + + def __init__(self, args_usage): + super().__init__() + self.args_usage = args_usage + + def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: + if isinstance(call.op, tvm.ir.Op) and call.op == relay.op.get("concatenate"): + args = [] + + # Expand tuples + for arg in call.args: + if isinstance(arg, relay.Tuple): + args.extend(arg.fields) + else: + args.append(arg) + new_args = [] + for arg in args: + visited = self.visit(arg) + if self.args_usage[arg] > 1: + # Add copy operation + lut = relay.const([], "int8") + new_op = op.ethosu_identity(visited, lut) + new_args.append(new_op) + self.args_usage[arg] -= 1 + else: + new_args.append(visited) + + new_args = [relay.Tuple(new_args)] + else: + new_args = [self.visit(arg) for arg in call.args] + new_op = self.visit(call.op) + new_call = _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, call.span + ) + return new_call + + +@util.create_npu_function_pass(opt_level=1) +class CopyReusedConcatBuffers: + """Register CopyReusedConcatBuffers as a Relay pass.""" + + def transform_npu_function(self, _, func: relay.Function) -> relay.Function: + """A pass to copy concatenation arguments which are used more than once in + concatenation operation. This is the preparation for the next RemoveConcatenates + pass to prevent a situation where an argument used in multiple concatenations + will be written to only one resulting buffer.""" + + analyze = AnalyzeConcatArgs() + analyze.visit(func) + + return ConcatArgsCopier(analyze.args_usage).visit(func) + + def __call__(self, *args, **kwargs): + pass + + def IdentityOptimizer(): # pylint: disable=invalid-name """Pass that removes redundant identities @@ -585,6 +685,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: """ mod = OutlineCompilerFunctions("ethos-u")(mod) mod = LegalizeEthosU()(mod) + mod = CopyReusedConcatBuffers()(mod) mod = LUTsOptimizer()(mod) mod = relay.transform.InferType()(mod) mod = IdentityOptimizer()(mod) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index d56b8b6ec943..e094bb74b2e1 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1170,6 +1170,24 @@ def concat_func(*inputs): infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False) +def test_tflite_concat_with_reused_args(): + np.random.seed(0) + shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)] + axis = 2 + accel_type = "ethos-u55-256" + + @tf.function + def concat_func(*inputs): + op = tf.add(inputs[0], inputs[1]) + op2 = tf.concat((inputs[0], inputs[2], op), axis) + op = tf.concat((inputs[0], inputs[3], op), axis) + op = tf.nn.max_pool2d(op, (1, 1), (1, 2), "SAME") + op = tf.add(op, op2) + return op + + infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) def test_tflite_sigmoid(accel_type): np.random.seed(0)