Skip to content

Commit c659034

Browse files
ethansfngfacebook-github-bot
authored andcommitted
Handle same tensor appearing multiple times in the cat input
Summary: GenerateCatNopConstraints pass in memory planning was crashing when the same tensor appeared multiple times in a single cat operation's input list at different positions. Differential Revision: D88439174
1 parent 37b9041 commit c659034

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

backends/cadence/aot/memory_constraints.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,10 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
417417
return not self.constraint.is_alias_of(source_info.source, node)
418418
return False
419419

420+
def has_relative_placement_constraint(self, node: torch.fx.Node) -> bool:
421+
"""Return if `node` already has any relative placement constraint."""
422+
return self.constraint.get_relative_placement_source(node) is not None
423+
420424
# Return true if the cat node performs concatenation along outermost dimension
421425
def is_cat_along_outermost_dim(
422426
self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
@@ -481,6 +485,17 @@ def is_removable_cat_op(
481485
if any(self.is_slice_view(arg) for arg in cat_tensors):
482486
return False
483487

488+
# If any of the tensors already has a relative placement constraint,
489+
# we cannot add a new constraint for this cat without conflicting.
490+
# This can happen when a tensor is used in multiple cat operations.
491+
if any(self.has_relative_placement_constraint(arg) for arg in cat_tensors):
492+
return False
493+
494+
# If the same tensor appears multiple times in the cat inputs,
495+
# we cannot place it at multiple different offsets relative to the output.
496+
if len(cat_tensors) != len(set(cat_tensors)):
497+
return False
498+
484499
# Many ops in HiFi require the input to be aligned to 8-byte boundary.
485500
# If the cat is not the graph's output, then ensure that the relative
486501
# offset of any concatenated non-placeholder tensor is a multiple of

0 commit comments

Comments
 (0)