From 24709eb3cf431c775f187f7f9651658f38225481 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Sat, 31 Dec 2022 19:52:33 +0800 Subject: [PATCH 1/6] [autoparallel] hook node meta on graph nodes for checkpoint solver --- .../meta_registry/binary_elementwise_ops.py | 2 +- .../auto_parallel/meta_profiler/metainfo.py | 22 +--- .../passes/comm_metainfo_pass.py | 120 ++++++++++++++++++ .../auto_parallel/passes/meta_info_prop.py | 5 +- .../passes/runtime_apply_pass.py | 49 ------- .../binary_elementwise_handler.py | 2 +- .../tensor_shard/node_handler/node_handler.py | 3 +- .../node_handler/reshape_handler.py | 4 +- .../node_handler/unary_elementwise_handler.py | 4 +- 9 files changed, 134 insertions(+), 77 deletions(-) create mode 100644 colossalai/auto_parallel/passes/comm_metainfo_pass.py diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index eb8042368d04..b4cc58d05c0c 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_op_data.data, device='meta')] + fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')] fwd_buffer = [] fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py index 1f34637139e6..ff76e3059fef 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -1,6 +1,5 @@ from typing import Callable, List 
-import numpy as np import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( @@ -71,25 +70,12 @@ def target(self, target: Callable) -> None: if self._strategy is not None and self._target is not None: self.compute_metainfo() - def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: + def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: """ - Compute sharded meta tensor based on the given data and sharding spec. + Compute sharded opdata based on the given data and sharding spec. """ - shard_sequnce = sharding_spec.sharding_sequence - device_mesh = sharding_spec.device_mesh - shape = operation_data.data.shape - - new_shape = [] - for dim, shard in zip(shape, shard_sequnce): - if shard.is_replica: - # replica - new_shape.append(dim) - else: - # sharded according to device_mesh shape - new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list]))) - return OperationData(name=operation_data.name, - data=torch.zeros(new_shape, device="meta"), + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), type=operation_data.type, logical_shape=operation_data.logical_shape) @@ -113,7 +99,7 @@ def compute_metainfo(self): save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION # construct args for meta_func - args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] + args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()] # construct kwargs if self.target in INPLACE_MODULE: diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py new file mode 100644 index 000000000000..9a813fc5716c --- /dev/null +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -0,0 +1,120 @@ +from copy import deepcopy +from typing import Dict, List + +import 
torch +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> MetaInfo: + # get comm_action_sequence and total_cost from shape_consistency_manager + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + meta_info = MetaInfo() + # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel + # get mem cost for MetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + # extract user that has _meta_data and extract element length + input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + element_length = input_node._meta_data.element_size() + + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for MetaInfo + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + # get tensor shape for MetaInfo + origin_sharding_spec: ShardingSpec + target_sharding_spec: ShardingSpec 
+ input_shape = origin_sharding_spec.get_sharded_shape_per_device() + output_shape = target_sharding_spec.get_sharded_shape_per_device() + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + +def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: + """ + This method is used to construct `MetaInfo` for shape consistency node + """ + + # extract node index and user node index + args = node.args + node_index, user_node_index = args[3], args[4] + origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ + node_index][user_node_index] + + return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: + # extract node_index and op_data_name + node_index, op_data_name = node.args[2], node.args[3] + + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + # this case is for all_reduce, there will be no memory cost + meta_info = MetaInfo() + output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + element_length = output_node._meta_data.element_size() + + total_cost = comm_action.comm_spec.get_comm_cost() + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + else: + # this case will be handled by shape consistency manager + origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], 
comm_action.comm_spec[ + 'tgt_spec'] + meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + return meta_info + + +def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, + comm_actions_dict: Dict): + """ + The method manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph. + """ + for node in gm.graph.nodes: + if node.target == runtime_apply: + setattr(node, 'best_metainfo', + _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict)) + elif node.target == runtime_comm_spec_apply: + setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + else: + pass diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 1628bb28596a..7bf34ed5fbc5 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -59,6 +59,7 @@ def run(self) -> GraphModule: """ for node in self.module.graph.nodes: node: Node + print(node) self.func_dict[node.op](node) @compatibility(is_backward_compatible=False) @@ -68,7 +69,7 @@ def placeholder_handler(self, node: Node) -> None: """ graph_info = GraphInfo() out = _normalize_tuple(getattr(node, '_meta_data', None)) - graph_info.fwd_out = list(out) + graph_info.fwd_out = list(out) if out[0] is not None else [] node.meta = {**asdict(graph_info)} @compatibility(is_backward_compatible=False) @@ -97,7 +98,7 @@ def node_handler(self, node: Node) -> None: """ Handle other kind of nodes """ - assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}" + assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}" graph_info = GraphInfo() meta_info = node.best_metainfo meta_info: MetaInfo diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py 
b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 5d224542c2f8..7f2aac42b7f8 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: return rst -def construct_meta_info(node: Node, user_node: Node) -> MetaInfo: - """ - This method is used to construct `MetaInto` for shape consistency node - TODO: Actually we could attain the cost information from resharding cost in node - handler, we should modify this part in the future. - """ - - def compute_shape(sharding_spec: ShardingSpec): - shape = sharding_spec.entire_shape - new_shape = [] - for dim, shard in sharding_spec.dim_partition_dict.items(): - new_shape.append(shape[dim] // len(shard)) - return new_shape - - meta_info = MetaInfo() - origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name( - str(node.name)) - _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) - - # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel - # get mem cost for MetaInfo - mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) - element_length = node._meta_data.element_size() - mem_cost.fwd.activation *= element_length - mem_cost.fwd.temp *= element_length - mem_cost.bwd.activation *= element_length - mem_cost.bwd.temp *= element_length - mem_cost.total.activation *= element_length - - meta_info.memory_cost = mem_cost - - # get computation cost for MetaInfo - compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total']) - meta_info.compute_cost = compute_cost - - # get tensor shape for MetaInfo - input_shape = compute_shape(origin_sharding_spec) - output_shape = compute_shape(target_sharding_spec) - - meta_info.fwd_in = [torch.rand(input_shape, 
device='meta')] - meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] - - return meta_info - - def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): """ This method will be invoked during runtime to apply the comm action following the instruction of comm spec. @@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - meta_info = construct_meta_info(node, user_node) - setattr(shape_consistency_node, 'best_metainfo', meta_info) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index e8ae363e97a1..f510f74776b6 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(NodeHandler): +class BinaryElementwiseHandler(MetaInfoNodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 7dea256b3ac4..af3cb5810d11 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -138,8 +138,7 @@ def get_target_function(self) -> callable: return None if self.node.op == 'call_module': - submod = self.node.graph.owning_module.get_submodule(self.node.target) - target = type(submod) + target = self.node.graph.owning_module.get_submodule(self.node.target) elif self.node.op == 'call_function': target = self.node.target elif self.node.op == 'call_method': diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index b463487165cb..7763b1884025 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator @@ -13,7 +13,7 @@ @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(NodeHandler): +class ReshapeHandler(MetaInfoNodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. 
""" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index bda160906517..0362de780d7a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator @@ -19,7 +19,7 @@ @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) -class UnaryElementwiseHandler(NodeHandler): +class UnaryElementwiseHandler(MetaInfoNodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. 
""" From b1caa03711e69c3bf42a972f7f070eabb9c566dc Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Sat, 31 Dec 2022 20:08:03 +0800 Subject: [PATCH 2/6] [autoparallel] polish code --- colossalai/auto_parallel/passes/comm_metainfo_pass.py | 2 ++ colossalai/auto_parallel/passes/meta_info_prop.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index 9a813fc5716c..d8cc6088bb42 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -10,6 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, + MemoryCost, OperationData, OperationDataType, TrainCycleItem, @@ -84,6 +85,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M if isinstance(comm_action.comm_spec, CommSpec): # this case is for all_reduce, there will be no memory cost meta_info = MetaInfo() + meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost()) output_node = next(n for n in node.users if hasattr(n, '_meta_data')) element_length = output_node._meta_data.element_size() diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 7bf34ed5fbc5..a23b3d4228b7 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -59,7 +59,6 @@ def run(self) -> GraphModule: """ for node in self.module.graph.nodes: node: Node - print(node) self.func_dict[node.op](node) @compatibility(is_backward_compatible=False) From b161a876218e6e0f5264cf18ee3bf9aaef70b56c Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Sat, 31 Dec 2022 20:20:07 +0800 Subject: [PATCH 3/6] [autoparallel] restore some node handlers --- .../tensor_shard/node_handler/binary_elementwise_handler.py | 2 +- 
.../tensor_shard/node_handler/reshape_handler.py | 4 ++-- .../tensor_shard/node_handler/unary_elementwise_handler.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index f510f74776b6..e8ae363e97a1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(MetaInfoNodeHandler): +class BinaryElementwiseHandler(NodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index 7763b1884025..b463487165cb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator @@ -13,7 +13,7 @@ @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(MetaInfoNodeHandler): +class ReshapeHandler(NodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. 
""" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 0362de780d7a..bda160906517 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator @@ -19,7 +19,7 @@ @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) -class UnaryElementwiseHandler(MetaInfoNodeHandler): +class UnaryElementwiseHandler(NodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. 
""" From ec8150fe7989e5281998fa1aefc179c8e8c43222 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 2 Jan 2023 09:15:40 +0800 Subject: [PATCH 4/6] colossalai/auto_parallel/passes/meta_info_prop.py --- colossalai/auto_parallel/passes/meta_info_prop.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index a23b3d4228b7..cf5f9d6d0d50 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -159,4 +159,11 @@ def node_handler(self, node: Node) -> None: graph_info.fwd_mem_tmp = memory_cost.fwd.temp graph_info.bwd_mem_tmp = memory_cost.bwd.temp + # fetch flop information + # here we use fwd_time and bwd_time to deal with the case that + # communication cost is a float + compute_cost = meta_info.compute_cost + graph_info.fwd_time = compute_cost.fwd + graph_info.bwd_time = compute_cost.bwd + node.meta = {**asdict(graph_info)} From d5e3d6ee559c7bd7d6dbd1d351bb590ccd6a0cff Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 2 Jan 2023 09:23:58 +0800 Subject: [PATCH 5/6] [autoparallel] remove some unused import --- .../auto_parallel/passes/comm_metainfo_pass.py | 13 ++----------- colossalai/auto_parallel/passes/meta_info_prop.py | 7 +++---- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index d8cc6088bb42..5ab6289b7de7 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -1,5 +1,4 @@ -from copy import deepcopy -from typing import Dict, List +from typing import Dict import torch from torch.fx import GraphModule @@ -7,15 +6,7 @@ from colossalai.auto_parallel.meta_profiler import MetaInfo from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, 
runtime_comm_spec_apply -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - OperationData, - OperationDataType, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index cf5f9d6d0d50..037165756ee0 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -1,15 +1,14 @@ import uuid from dataclasses import asdict -from typing import Any, Dict, List, NamedTuple, Tuple +from typing import List import torch import torch.fx from torch.fx import GraphModule -from torch.fx.node import Argument, Node, Target -from torch.utils._pytree import tree_map +from torch.fx.node import Node from colossalai.auto_parallel.meta_profiler import MetaInfo -from colossalai.fx._compatibility import compatibility, is_compatible_with_meta +from colossalai.fx._compatibility import compatibility from colossalai.fx.profiler import GraphInfo from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS From 5a44d244e43380ee852432fbb48ddd0b1d270a45 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Mon, 2 Jan 2023 10:05:38 +0800 Subject: [PATCH 6/6] [autoparallel] hook bwd_mem_out --- colossalai/auto_parallel/passes/meta_info_prop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 037165756ee0..607f7e17ec73 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ 
-157,6 +157,7 @@ def node_handler(self, node: Node) -> None: memory_cost = meta_info.memory_cost graph_info.fwd_mem_tmp = memory_cost.fwd.temp graph_info.bwd_mem_tmp = memory_cost.bwd.temp + graph_info.bwd_mem_out = memory_cost.bwd.activation # fetch flop information # here we use fwd_time and bwd_time to deal with the case that