Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,106 @@ def replicate_pads(mod):
return mod


class AnalyzeConcatArgs(ExprVisitor):
"""Traverses the graph to determine which arguments were passed into the
concatenation operation and how many times they are used. The result is
maintained in `args_usage` and is a dictionary where the key is the concatenation argument and
the value is the number of uses of this argument.

Attributes
----------
args_usage : Dict[tvm.relay.expr.Call, int]
Mapping from concatenation arguments to count their usage as concatenate arguments.
"""

def __init__(self):
self.args_usage = defaultdict(int)
super().__init__()

def visit_call(self, call: relay.Call):
args = []

# Expand tuples
for arg in call.args:
if isinstance(arg, relay.Tuple):
args.extend(arg.fields)
else:
args.append(arg)

if isinstance(call.op, tvm.ir.Op) and call.op.name == "concatenate":
for arg in args:
if isinstance(arg, relay.Call):
self.args_usage[arg] += 1

super().visit_call(call)


class ConcatArgsCopier(ExprMutator):
"""A pass for copying concatenation arguments that are used in multiple concatenation
operations. For a concatenation argument that is used n times, n - 1 copy operations
will be created.

Attributes
----------
args_usage : Dict[tvm.relay.expr.Call, int]
Mapping from concatenation arguments to count their usage as concatenate arguments.
"""

def __init__(self, args_usage):
super().__init__()
self.args_usage = args_usage

def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
if isinstance(call.op, tvm.ir.Op) and call.op == relay.op.get("concatenate"):
args = []

# Expand tuples
for arg in call.args:
if isinstance(arg, relay.Tuple):
args.extend(arg.fields)
else:
args.append(arg)
new_args = []
for arg in args:
visited = self.visit(arg)
if self.args_usage[arg] > 1:
# Add copy operation
lut = relay.const([], "int8")
new_op = op.ethosu_identity(visited, lut)
new_args.append(new_op)
self.args_usage[arg] -= 1
else:
new_args.append(visited)

new_args = [relay.Tuple(new_args)]
else:
new_args = [self.visit(arg) for arg in call.args]
new_op = self.visit(call.op)
new_call = _expr.CallWithFields(
call, new_op, new_args, call.attrs, call.type_args, None, call.span
)
return new_call


@util.create_npu_function_pass(opt_level=1)
class CopyReusedConcatBuffers:
"""Register CopyReusedConcatBuffers as a Relay pass."""

def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""A pass to copy concatenation arguments which are used more than once in
concatenation operation. This is the preparation for the next RemoveConcatenates
pass to prevent a situation where an argument used in multiple concatenations
will be written to only one resulting buffer."""

analyze = AnalyzeConcatArgs()
analyze.visit(func)

return ConcatArgsCopier(analyze.args_usage).visit(func)

def __call__(self, *args, **kwargs):
pass


def IdentityOptimizer(): # pylint: disable=invalid-name
"""Pass that removes redundant identities

Expand Down Expand Up @@ -585,6 +685,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
"""
mod = OutlineCompilerFunctions("ethos-u")(mod)
mod = LegalizeEthosU()(mod)
mod = CopyReusedConcatBuffers()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
mod = IdentityOptimizer()(mod)
Expand Down
18 changes: 18 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,24 @@ def concat_func(*inputs):
infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False)


def test_tflite_concat_with_reused_args():
np.random.seed(0)
shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]
axis = 2
accel_type = "ethos-u55-256"

@tf.function
def concat_func(*inputs):
op = tf.add(inputs[0], inputs[1])
op2 = tf.concat((inputs[0], inputs[2], op), axis)
op = tf.concat((inputs[0], inputs[3], op), axis)
op = tf.nn.max_pool2d(op, (1, 1), (1, 2), "SAME")
op = tf.add(op, op2)
return op

infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_sigmoid(accel_type):
np.random.seed(0)
Expand Down