24 changes: 14 additions & 10 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -35,10 +35,11 @@ def __init__(
free_memory: float = -1.0,
requires_linearize: bool = False,
cnode: List[str] = None,
+ optim_multiplier: float = 1.0,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
"""``CheckpointSolverBase`` class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for target
computing graph.

Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@@ -49,9 +50,11 @@ def __init__(
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
+ optim_multiplier (float, optional): The multiplier of extra weight storage for the
+ ``torch.optim.Optimizer``. Default to 1.0.

Warnings:
- `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
+ Meta information of the graph is required for any ``CheckpointSolver``.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
@@ -61,13 +64,14 @@ def __init__(
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())

- # check if `MetaInfoProp` is done
+ # check if has meta information
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
)

- self.free_memory = free_memory
- self.parameter_size = _get_param_size(self.graph.owning_module)
+ # parameter memory = parameter size + optimizer extra weight storage
+ self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
@@ -97,7 +101,7 @@ def _linearize_graph(self) -> List[List[Node]]:
the actual 'node' in linearized manner.

Remarks:
- Do merge the inplace ops into the previous node.
+ Do merge the inplace ops and shape-consistency ops into the previous node.
"""

# Common nodes are type of nodes that could be seen as attributes and remain
@@ -136,7 +140,7 @@ def _is_sink() -> bool:
"""

def _is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
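A note on the budget arithmetic introduced above: the solver now deducts the parameters and the optimizer's extra weight storage from the user-supplied budget up front, so only the remainder is available for activations. Below is a minimal sketch of that accounting; the helper name and the worked numbers are illustrative, not part of this patch.

# Sketch of the free-memory computation above (illustrative helper, not the PR's API).
def effective_free_memory(free_memory: float, param_size: float,
                          optim_multiplier: float = 1.0) -> float:
    """Budget left for activations after parameters and optimizer state.

    The (optim_multiplier + 1) factor counts the parameters themselves plus
    optim_multiplier extra copies of parameter-sized optimizer storage.
    """
    return free_memory - param_size * (optim_multiplier + 1)

# e.g. a 16 GiB budget with 2 GiB of parameters and an Adam-like optimizer
# (two extra buffers per parameter, so optim_multiplier=2.0 under this reading):
# 16 - 2 * (2.0 + 1) = 10 GiB left for activations.

Under that reading, the default of 1.0 would match optimizers holding one extra buffer per parameter (e.g. SGD with momentum), while Adam-style optimizers tracking exp_avg and exp_avg_sq would warrant 2.0.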
6 changes: 3 additions & 3 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -19,9 +19,9 @@ def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
Note that this algorithm targets at memory optimization only, using techniques in appendix A.

Usage:
- Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+ Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
- code to find a solution using `CheckpointSolverChen`:
+ code to find a solution using ``CheckpointSolverChen``:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
@@ -74,7 +74,7 @@ def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
def grid_search(self) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
- Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
+ Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
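For context on the grid_search docstring above: run_chen_greedy(0) returns the approximation b ≈ √(xy) from appendix A of the paper, and the solver then re-runs the greedy allocation over budgets in [√2/2·b, √2·b], split into num_grids steps, keeping the cheapest result. A sketch of that sweep, assuming only a run_chen_greedy callable with the Tuple[Set, int] return shape shown in the diff (the name grid_search_sketch is hypothetical):

import math
from typing import Callable, Set, Tuple

def grid_search_sketch(run_chen_greedy: Callable[[int], Tuple[Set, int]],
                       num_grids: int = 6) -> Set:
    # First pass with b = 0 yields the approximate bucket size b ~ sqrt(xy).
    _, b_approx = run_chen_greedy(0)
    b_min = math.floor(b_approx / math.sqrt(2))
    b_max = math.ceil(b_approx * math.sqrt(2))
    step = max(1, (b_max - b_min) // num_grids)
    best_set, best_cost = set(), math.inf
    # Re-run the greedy allocation for each candidate budget, keep the cheapest.
    for b in range(b_min, b_max + 1, step):
        ckpt_set, cost = run_chen_greedy(b)
        if cost < best_cost:
            best_set, best_cost = ckpt_set, cost
    return best_set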
19 changes: 13 additions & 6 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -23,15 +23,20 @@

class CheckpointSolverRotor(CheckpointSolverBase):

- def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
+ def __init__(self,
+ graph: Graph,
+ free_memory: float = -1,
+ cnode: List[str] = None,
+ memory_slots: int = 500,
+ optim_multiplier: float = 1.0):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.

Usage:
- Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+ Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
- code to find a solution using `CheckpointSolverRotor`:
+ code to find a solution using ``CheckpointSolverRotor``:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
@@ -42,6 +47,8 @@ def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+ optim_multiplier (float, optional): The multiplier of extra weight storage for the
+ ``torch.optim.Optimizer``. Default to 1.0.
"""
super().__init__(graph, free_memory, True, cnode)
self.memory_slots = memory_slots
@@ -298,8 +305,8 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
- cost_table (List[Any]): See `._compute_table()` for definitions
- back_ptr (List[Any]): See `._compute_table()` for definitions
+ cost_table (List[Any]): See ``._compute_table()`` for definitions
+ back_ptr (List[Any]): See ``._compute_table()`` for definitions

Raises:
ValueError: Can not process the chain.
@@ -340,7 +347,7 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A

@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
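Putting the rotor changes together: the solver discretizes free_memory into memory_slots units before filling its dynamic-programming cost table, so a larger memory_slots buys a finer-grained budget at the cost of solver time. A usage sketch consistent with the docstring above; gm is an assumed torch.fx.GraphModule whose nodes already carry meta information, and the import path is an assumption about the package layout.

import torch
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor

# mem_get_info returns (free_bytes, total_bytes) for the given device.
free_mem = torch.cuda.mem_get_info(device=0)[0]

solver = CheckpointSolverRotor(
    gm.graph,               # gm: a traced torch.fx.GraphModule with meta info attached
    free_memory=free_mem,
    memory_slots=500,       # finer discretization -> tighter, but slower, solution
    optim_multiplier=1.0,   # extra optimizer weight storage per unit of parameter memory
)
gm.graph = solver.solve(force_python=True)  # drop the flag to use the C solver instead
gm.recompile()              # regenerate forward() so the checkpoint annotations take effect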