From 354f04ee251413d2ec6b81bd8ed2f5f9551f44b9 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 2 Jan 2023 23:29:06 +0800 Subject: [PATCH 1/2] [autockpt] make it work. --- .../auto_parallel/passes/comm_metainfo_pass.py | 14 +++++++------- .../node_handler/binary_elementwise_handler.py | 2 +- .../tensor_shard/node_handler/reshape_handler.py | 4 ++-- .../node_handler/unary_elementwise_handler.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index 5ab6289b7de7..ab3acb0563ff 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, return meta_info -def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: """ This method is used to construct `MetaInto` for shape consistency node """ @@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s # extract node index and user node index args = node.args node_index, user_node_index = args[3], args[4] - origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ - node_index][user_node_index] + origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ + user_node_index] return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) @@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M return meta_info -def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, - comm_actions_dict: Dict): +def comm_metainfo_pass(gm: GraphModule, 
sharding_spec_dict: Dict, origin_spec_dict: Dict, + comm_actions_dict: Dict) -> GraphModule: """ The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_metainfo', - _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict)) + setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass + return gm diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index e8ae363e97a1..f510f74776b6 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(NodeHandler): +class BinaryElementwiseHandler(MetaInfoNodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index b463487165cb..7763b1884025 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator @@ -13,7 +13,7 @@ @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(NodeHandler): +class ReshapeHandler(MetaInfoNodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index bda160906517..0362de780d7a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator @@ -19,7 +19,7 @@ @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) -class UnaryElementwiseHandler(NodeHandler): +class UnaryElementwiseHandler(MetaInfoNodeHandler): """ A UnaryElementwiseHandler 
which deals with the sharding strategies for UnaryElementwise Op. """ From 47446300b069ced051892994632e4e423fd43cc8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 3 Jan 2023 14:46:04 +0800 Subject: [PATCH 2/2] [autockpt] linearize / merge shape-consistency nodes. --- .../checkpoint/ckpt_solver_base.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index 63eff31b2da7..ecccef8d7620 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -5,8 +5,12 @@ import torch from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -from colossalai.fx.profiler.memory_utils import is_inplace __all___ = ['CheckpointSolverBase'] @@ -131,7 +135,23 @@ def _is_sink() -> bool: bool """ - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + def _is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is a shape-consistency node (i.e. ``runtime_apply``, ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) # make sure that item in cnode is valid if self.cnode: