From 354f04ee251413d2ec6b81bd8ed2f5f9551f44b9 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 2 Jan 2023 23:29:06 +0800 Subject: [PATCH 1/3] [autockpt] make it work. --- .../auto_parallel/passes/comm_metainfo_pass.py | 14 +++++++------- .../node_handler/binary_elementwise_handler.py | 2 +- .../tensor_shard/node_handler/reshape_handler.py | 4 ++-- .../node_handler/unary_elementwise_handler.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index 5ab6289b7de7..ab3acb0563ff 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, return meta_info -def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: """ This method is used to construct `MetaInto` for shape consistency node """ @@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s # extract node index and user node index args = node.args node_index, user_node_index = args[3], args[4] - origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ - node_index][user_node_index] + origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ + user_node_index] return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) @@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M return meta_info -def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, - comm_actions_dict: Dict): +def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, + comm_actions_dict: Dict) -> GraphModule: """ The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_metainfo', - _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict)) + setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass + return gm diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index e8ae363e97a1..f510f74776b6 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(NodeHandler): +class BinaryElementwiseHandler(MetaInfoNodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index b463487165cb..7763b1884025 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator @@ -13,7 +13,7 @@ @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(NodeHandler): +class ReshapeHandler(MetaInfoNodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index bda160906517..0362de780d7a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,7 +3,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator @@ -19,7 +19,7 @@ @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) -class UnaryElementwiseHandler(NodeHandler): +class UnaryElementwiseHandler(MetaInfoNodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. """ From 47446300b069ced051892994632e4e423fd43cc8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 3 Jan 2023 14:46:04 +0800 Subject: [PATCH 2/3] [autockpt] linearize / merge shape-consistency nodes. --- .../checkpoint/ckpt_solver_base.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index 63eff31b2da7..ecccef8d7620 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -5,8 +5,12 @@ import torch from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -from colossalai.fx.profiler.memory_utils import is_inplace __all___ = ['CheckpointSolverBase'] @@ -131,7 +135,23 @@ def _is_sink() -> bool: bool """ - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + def _is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) # make sure that item in cnode is valid if self.cnode: From 9ea5c1b33421cbd596bf104e7f1bfaaf78202440 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 3 Jan 2023 16:45:30 +0800 Subject: [PATCH 3/3] [autockpt] considering parameter and optimizer weights. --- .../checkpoint/ckpt_solver_base.py | 24 +++++++++++-------- .../checkpoint/ckpt_solver_chen.py | 6 ++--- .../checkpoint/ckpt_solver_rotor.py | 19 ++++++++++----- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index ecccef8d7620..b388d00ac553 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -35,10 +35,11 @@ def __init__( free_memory: float = -1.0, requires_linearize: bool = False, cnode: List[str] = None, + optim_multiplier: float = 1.0, ): - """CheckpointSolver class will integrate information provided by the components - and use an existing solver to find a possible optimal strategies combination for - target computing graph. + """``CheckpointSolverBase`` class will integrate information provided by the components + and use an existing solver to find a possible optimal strategies combination for target + computing graph. Existing Solvers: Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) @@ -49,9 +50,11 @@ def __init__( free_memory (float): Memory constraint for the solution. requires_linearize (bool): Whether the graph needs to be linearized. cnode (List[str], optional): Common node List, should be the subset of input. Default to None. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. Warnings: - `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required. + Meta information of the graph is required for any ``CheckpointSolver``. """ # super-dainiu: this graph is a temporary graph which can refer to # the owning module, but we will return another deepcopy of it after @@ -61,13 +64,14 @@ def __init__( _copy_output(graph, self.graph) self.graph.set_codegen(ActivationCheckpointCodeGen()) - # check if `MetaInfoProp` is done + # check if has meta information if any(len(node.meta) == 0 for node in self.graph.nodes): raise RuntimeError( - "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!") + "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!" + ) - self.free_memory = free_memory - self.parameter_size = _get_param_size(self.graph.owning_module) + # parameter memory = parameter size + optimizer extra weight storage + self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1) self.cnode = cnode self.requires_linearize = requires_linearize if self.requires_linearize: @@ -97,7 +101,7 @@ def _linearize_graph(self) -> List[List[Node]]: the actual 'node' in linearized manner. Remarks: - Do merge the inplace ops into the previous node. + Do merge the inplace ops and shape-consistency ops into the previous node. """ # Common nodes are type of nodes that could be seen as attributes and remain @@ -136,7 +140,7 @@ def _is_sink() -> bool: """ def _is_inplace(n: Node): - """Get the inplace argument from torch.fx.Node + """Get the inplace argument from ``torch.fx.Node`` """ inplace = False if n.op == "call_function": diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py index 58878253e99e..19b2ef5987c9 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -19,9 +19,9 @@ def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): Note that this algorithm targets at memory optimization only, using techniques in appendix A. Usage: - Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + Assume that we have a ``GraphModule``, and we have already done the extractions to the graph to retrieve all information needed, then we could use the following - code to find a solution using `CheckpointSolverChen`: + code to find a solution using ``CheckpointSolverChen``: >>> solver = CheckpointSolverChen(gm.graph) >>> chen_graph = solver.solve() >>> gm.graph = chen_graph # set the graph to a new graph @@ -74,7 +74,7 @@ def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: def grid_search(self) -> Set: """ Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. - Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A. + Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A. """ _, b_approx = self.run_chen_greedy(0) b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index cd5b70d110dc..5cc57fca0cff 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -23,15 +23,20 @@ class CheckpointSolverRotor(CheckpointSolverBase): - def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500): + def __init__(self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0): """This is the simple implementation of dynamic programming algorithm rotor in https://hal.inria.fr/hal-02352969. Some code are adapted from https://gitlab.inria.fr/hiepacs/rotor. Usage: - Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + Assume that we have a ``GraphModule``, and we have already done the extractions to the graph to retrieve all information needed, then we could use the following - code to find a solution using `CheckpointSolverRotor`: + code to find a solution using ``CheckpointSolverRotor``: >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0]) >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver >>> gm.graph = rotor_graph # set the graph to a new graph @@ -42,6 +47,8 @@ def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = Non Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1. cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. """ super().__init__(graph, free_memory, True, cnode) self.memory_slots = memory_slots @@ -298,8 +305,8 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A lhs (int): The left index of the interval to backtrack. rhs (int): The right index of the interval to backtrack. budget (int): The memory budget for processing this interval. - cost_table (List[Any]): See `._compute_table()` for definitions - back_ptr (List[Any]): See `._compute_table()` for definitions + cost_table (List[Any]): See ``._compute_table()`` for definitions + back_ptr (List[Any]): See ``._compute_table()`` for definitions Raises: ValueError: Can not process the chain. @@ -340,7 +347,7 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A @staticmethod def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): - """Annotate the nodes in the node_list with activation checkpoint from the sequence. + """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence. Args: sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.