From 20d1223c274cd139463a1e7ff3d8c5e9889f2162 Mon Sep 17 00:00:00 2001 From: csric Date: Mon, 27 Mar 2023 18:53:14 +0800 Subject: [PATCH] [NFC] polish colossalai/fx/passes/split_module.py code style --- colossalai/fx/passes/split_module.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index bc257edc8c89..9bc4bf1f5c42 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -1,9 +1,10 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional + import torch -from torch.fx.graph_module import GraphModule -from typing import Callable, List, Dict, Any, Optional -from torch.fx._compatibility import compatibility from packaging import version -import inspect +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule @compatibility(is_backward_compatible=True) @@ -38,7 +39,7 @@ def split_module( m: GraphModule, root_m: torch.nn.Module, split_callback: Callable[[torch.fx.node.Node], int], - merge_output = False, + merge_output=False, ): """ Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py @@ -132,10 +133,8 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, use_partition.inputs.setdefault(def_node.name) if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - - def record_output( - def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node] - ): # noqa: B950 + + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -291,7 +290,7 @@ def record_output( for partition_name in sorted_partitions: partition = partitions[partition_name] - + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) return new_gm