diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 2806ef8a4699..7ed69e1e9bf8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1184,7 +1184,7 @@ def callback( # Find the tensors that are inputs to the concat and the scales and zero points concat_args = list() for arg in post.args: - if isinstance(arg, tvm.relay.expr.Call): + if isinstance(arg, (tvm.relay.expr.Call, tvm.relay.expr.TupleGetItem)): concat_args.append(arg) axis = post.op.body.attrs.axis diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 66809a775f48..f69c114cabd1 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1170,6 +1170,22 @@ def concat_func(*inputs): infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False) +def test_tflite_unstack_concat(): + np.random.seed(0) + shapes = [(2, 4, 16)] + axis = 1 + accel_type = "ethos-u55-256" + + @tf.function + def concat_func(input): + inputs = tf.unstack(input) + inputs.reverse() + op = tf.concat(inputs, axis) + return op + + infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False) + + def test_tflite_concat_with_reused_args(): np.random.seed(0) shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]