@@ -25,12 +25,18 @@ class RemoveCloneOpsTransform(ExportPass):
2525 exir_ops .edge .dim_order_ops ._clone_dim_order .default ,
2626 }
2727
28- def __init__ (self , preserve_input_output_copies : bool = False ) -> None :
28+ def __init__ (
29+ self ,
30+ preserve_input_output_copies : bool = False ,
31+ eliminate_quant_dequant_pairs : bool = True ,
32+ ) -> None :
2933 super ().__init__ ()
3034 self ._preserve_input_output_copies = preserve_input_output_copies
35+ self ._eliminate_quant_dequant_pairs = eliminate_quant_dequant_pairs
3136
32- def _remove (self , graph_module : torch .fx .GraphModule ) -> None :
37+ def _remove (self , graph_module : torch .fx .GraphModule ) -> bool :
3338 dequant_nodes = []
39+ modified = False
3440
3541 for n in graph_module .graph .nodes :
3642 if n .target not in self .clone_ops :
@@ -44,20 +50,26 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
4450 if self ._is_input_output_copy (n ) and self ._preserve_input_output_copies :
4551 continue
4652
53+ modified = True
4754 to_be_removed = n
4855 for user_n in list (n .users .keys ()):
4956 user_n .replace_input_with (n , n .args [0 ])
5057 if n .args [0 ].target in _DEQUANT_OPS :
5158 dequant_nodes += [n .args [0 ]]
5259 graph_module .graph .erase_node (to_be_removed )
5360
54- eliminate_dq_q (graph_module , dequant_nodes )
61+ if self ._eliminate_quant_dequant_pairs :
62+ eliminate_dq_q (graph_module , dequant_nodes )
63+
64+ return modified
5565
5666 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
57- self ._remove (graph_module )
58- graph_module .recompile ()
59- dead_code_elimination_pass (graph_module )
60- return PassResult (graph_module , True )
67+ if self ._remove (graph_module ):
68+ graph_module .recompile ()
69+ dead_code_elimination_pass (graph_module )
70+ return PassResult (graph_module , True )
71+ else :
72+ return PassResult (graph_module , False )
6173
6274 def _is_non_identity_clone (self , node : torch .fx .Node ) -> bool :
6375 """Return True if clone has modified memory layout or dim order."""
0 commit comments